fix of issue 600 (#625)
This commit is contained in:
parent
3dd0b8c186
commit
ca73b70d15
|
@ -930,7 +930,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def load_from_metrics(cls, weights_path, tags_csv):
|
||||
def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
|
||||
"""Primary way of loading model from csv weights path.
|
||||
|
||||
:param str weights_path: Path to a PyTorch checkpoint
|
||||
|
@ -975,9 +975,10 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
hparams = load_hparams_from_tags_csv(tags_csv)
|
||||
hparams.__setattr__('on_gpu', False)
|
||||
|
||||
# load on CPU only to avoid OOM issues
|
||||
# then its up to user to put back on GPUs
|
||||
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
|
||||
if map_location is not None:
|
||||
checkpoint = torch.load(weights_path, map_location=map_location)
|
||||
else:
|
||||
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
|
||||
|
||||
# load the state_dict on the model automatically
|
||||
model = cls(hparams)
|
||||
|
@ -989,7 +990,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
return model
|
||||
|
||||
@classmethod
|
||||
def load_from_checkpoint(cls, checkpoint_path):
|
||||
def load_from_checkpoint(cls, checkpoint_path, map_location=None):
|
||||
"""
|
||||
Primary way of loading model from a checkpoint
|
||||
:param checkpoint_path:
|
||||
|
@ -997,9 +998,11 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
:return:
|
||||
"""
|
||||
|
||||
# load on CPU only to avoid OOM issues
|
||||
# then its up to user to put back on GPUs
|
||||
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
|
||||
if map_location is not None:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
|
||||
|
||||
try:
|
||||
ckpt_hparams = checkpoint['hparams']
|
||||
except KeyError:
|
||||
|
|
Loading…
Reference in New Issue