fix of issue 600 (#625)

Hao Sheng 2019-12-14 20:24:46 -08:00 committed by William Falcon
parent 3dd0b8c186
commit ca73b70d15
1 changed file with 11 additions and 8 deletions


@@ -930,7 +930,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         return None
 
     @classmethod
-    def load_from_metrics(cls, weights_path, tags_csv):
+    def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
         """Primary way of loading model from csv weights path.
 
         :param str weights_path: Path to a PyTorch checkpoint
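
For context (not part of the diff): the new map_location argument is forwarded to torch.load, which accepts a device string, a torch.device, a dict remapping storage locations, or a callable. A minimal sketch of the accepted forms; the checkpoint filename is a placeholder:

    import torch

    # equivalent ways to control where restored tensors land
    ckpt = torch.load('weights.ckpt', map_location='cpu')                          # device string
    ckpt = torch.load('weights.ckpt', map_location=torch.device('cuda:0'))         # torch.device
    ckpt = torch.load('weights.ckpt', map_location={'cuda:1': 'cuda:0'})           # remap dict
    ckpt = torch.load('weights.ckpt', map_location=lambda storage, loc: storage)   # callable; keeps tensors on CPU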
@@ -975,9 +975,10 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         hparams = load_hparams_from_tags_csv(tags_csv)
         hparams.__setattr__('on_gpu', False)
 
-        # load on CPU only to avoid OOM issues
-        # then its up to user to put back on GPUs
-        checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
+        if map_location is not None:
+            checkpoint = torch.load(weights_path, map_location=map_location)
+        else:
+            checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
 
         # load the state_dict on the model automatically
         model = cls(hparams)
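
The pattern introduced above, restated as a standalone sketch (the function name and path are illustrative, not from the commit):

    import torch

    def _load_ckpt(path, map_location=None):
        # honor an explicit map_location when the caller supplies one
        if map_location is not None:
            return torch.load(path, map_location=map_location)
        # otherwise fall back to the old behavior: load everything onto CPU
        # to avoid GPU OOM, leaving device placement up to the user
        return torch.load(path, map_location=lambda storage, loc: storage)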
@@ -989,7 +990,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         return model
 
     @classmethod
-    def load_from_checkpoint(cls, checkpoint_path):
+    def load_from_checkpoint(cls, checkpoint_path, map_location=None):
         """
         Primary way of loading model from a checkpoint
         :param checkpoint_path:
@@ -997,9 +998,11 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         :return:
         """
-        # load on CPU only to avoid OOM issues
-        # then its up to user to put back on GPUs
-        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        if map_location is not None:
+            checkpoint = torch.load(checkpoint_path, map_location=map_location)
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
 
         try:
             ckpt_hparams = checkpoint['hparams']
         except KeyError:
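
With both classmethods now accepting map_location, callers can restore weights directly to a target device instead of always landing on CPU. A usage sketch; MyModule and the checkpoint path are hypothetical:

    import torch

    # old default, unchanged: tensors are loaded onto CPU
    model = MyModule.load_from_checkpoint('example.ckpt')

    # new: restore straight onto a chosen device
    model = MyModule.load_from_checkpoint('example.ckpt', map_location=torch.device('cuda:0'))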