From ca73b70d15bc8db3f57c1fd2d3bf152e6e1d7c4e Mon Sep 17 00:00:00 2001
From: Hao Sheng
Date: Sat, 14 Dec 2019 20:24:46 -0800
Subject: [PATCH] fix of issue 600 (#625)

---
 pytorch_lightning/core/lightning.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index ef2e4d4cc7..d2ba6c8a0b 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -930,7 +930,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         return None
 
     @classmethod
-    def load_from_metrics(cls, weights_path, tags_csv):
+    def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
         """Primary way of loading model from csv weights path.
 
         :param str weights_path: Path to a PyTorch checkpoint
@@ -975,9 +975,10 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         hparams = load_hparams_from_tags_csv(tags_csv)
         hparams.__setattr__('on_gpu', False)
 
-        # load on CPU only to avoid OOM issues
-        # then its up to user to put back on GPUs
-        checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
+        if map_location is not None:
+            checkpoint = torch.load(weights_path, map_location=map_location)
+        else:
+            checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)
 
         # load the state_dict on the model automatically
         model = cls(hparams)
@@ -989,7 +990,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         return model
 
     @classmethod
-    def load_from_checkpoint(cls, checkpoint_path):
+    def load_from_checkpoint(cls, checkpoint_path, map_location=None):
         """
         Primary way of loading model from a checkpoint
         :param checkpoint_path:
@@ -997,9 +998,11 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
         :return:
 
         """
-        # load on CPU only to avoid OOM issues
-        # then its up to user to put back on GPUs
-        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        if map_location is not None:
+            checkpoint = torch.load(checkpoint_path, map_location=map_location)
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+
         try:
             ckpt_hparams = checkpoint['hparams']
         except KeyError:
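
Usage note (not part of the patch): a minimal sketch of how the new map_location argument might be used once this change is applied. MyModel and the checkpoint filename below are hypothetical placeholders; map_location is forwarded to torch.load, so it accepts whatever torch.load accepts (a device, a string, a dict, or a callable).

    import torch

    # MyModel is assumed to be some LightningModule subclass defined by the user.
    # Pass map_location to control where tensors are placed while restoring weights.
    model = MyModel.load_from_checkpoint('example.ckpt', map_location=torch.device('cpu'))

    # Omitting map_location keeps the previous behaviour: weights are loaded onto CPU
    # (lambda storage, loc: storage) and moving the model to GPU is left to the caller.
    model = MyModel.load_from_checkpoint('example.ckpt')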