This commit is contained in:
William Falcon 2019-10-04 16:56:05 -04:00
commit 3a3ac73963
4 changed files with 9 additions and 10 deletions

View File

@ -1,3 +1,3 @@
rm -rf tests/save_dir*
rm -rf tests/mlruns_9964541/mlruns/
rm -rf tests/mlruns_*
coverage run --source pytorch_lightning -m py.test pytorch_lightning tests examples -v --doctest-modules

View File

@ -35,7 +35,7 @@ class ModelSummary(object):
out_sizes = []
input_ = self.model.example_input_array
if self.model.use_ddp:
if self.model.use_ddp or self.model.use_dp:
input_ = input_.cuda(0)
if self.model.trainer.use_amp:

View File

@ -130,17 +130,16 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
return None
@classmethod
def load_from_metrics(cls, weights_path, tags_csv, on_gpu):
def load_from_metrics(cls, weights_path, tags_csv):
"""
Primary way of loading model from csv weights path
:param weights_path:
:param tags_csv:
:param on_gpu:
:param map_location: dic for mapping storage {'cuda:1':'cuda:0'}
:return:
"""
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', on_gpu)
hparams.__setattr__('on_gpu', False)
# load on CPU only to avoid OOM issues
# then its up to user to put back on GPUs

View File

@ -128,8 +128,8 @@ def test_dp_resume():
dp_model = new_trainer.model
dp_model.eval()
for dataloader in trainer.get_train_dataloader():
run_prediction(dataloader, dp_model, dp=True)
dataloader = trainer.get_train_dataloader()
run_prediction(dataloader, dp_model, dp=True)
# new model
model = LightningTestModel(hparams)
@ -593,7 +593,7 @@ def test_no_val_module():
tags_path = logger.experiment.get_data_path(logger.experiment.name, logger.experiment.version)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path, on_gpu=False)
tags_csv=tags_path)
model_2.eval()
# make prediction
@ -1448,8 +1448,8 @@ def load_model(exp, save_dir, on_gpu, module_class=LightningTemplateModel):
weights_dir = os.path.join(save_dir, checkpoints[0])
trained_model = module_class.load_from_metrics(weights_path=weights_dir,
tags_csv=tags_path,
on_gpu=on_gpu)
tags_csv=tags_path
)
assert trained_model is not None, 'loading model failed'