Merge branch 'master' of https://github.com/williamFalcon/pytorch-lightning
This commit is contained in:
commit
3a3ac73963
|
@ -1,3 +1,3 @@
|
|||
rm -rf tests/save_dir*
|
||||
rm -rf tests/mlruns_9964541/mlruns/
|
||||
rm -rf tests/mlruns_*
|
||||
coverage run --source pytorch_lightning -m py.test pytorch_lightning tests examples -v --doctest-modules
|
||||
|
|
|
@ -35,7 +35,7 @@ class ModelSummary(object):
|
|||
out_sizes = []
|
||||
input_ = self.model.example_input_array
|
||||
|
||||
if self.model.use_ddp:
|
||||
if self.model.use_ddp or self.model.use_dp:
|
||||
input_ = input_.cuda(0)
|
||||
|
||||
if self.model.trainer.use_amp:
|
||||
|
|
|
@ -130,17 +130,16 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def load_from_metrics(cls, weights_path, tags_csv, on_gpu):
|
||||
def load_from_metrics(cls, weights_path, tags_csv):
|
||||
"""
|
||||
Primary way of loading model from csv weights path
|
||||
:param weights_path:
|
||||
:param tags_csv:
|
||||
:param on_gpu:
|
||||
:param map_location: dic for mapping storage {'cuda:1':'cuda:0'}
|
||||
:return:
|
||||
"""
|
||||
hparams = load_hparams_from_tags_csv(tags_csv)
|
||||
hparams.__setattr__('on_gpu', on_gpu)
|
||||
hparams.__setattr__('on_gpu', False)
|
||||
|
||||
# load on CPU only to avoid OOM issues
|
||||
# then its up to user to put back on GPUs
|
||||
|
|
|
@ -128,8 +128,8 @@ def test_dp_resume():
|
|||
dp_model = new_trainer.model
|
||||
dp_model.eval()
|
||||
|
||||
for dataloader in trainer.get_train_dataloader():
|
||||
run_prediction(dataloader, dp_model, dp=True)
|
||||
dataloader = trainer.get_train_dataloader()
|
||||
run_prediction(dataloader, dp_model, dp=True)
|
||||
|
||||
# new model
|
||||
model = LightningTestModel(hparams)
|
||||
|
@ -593,7 +593,7 @@ def test_no_val_module():
|
|||
tags_path = logger.experiment.get_data_path(logger.experiment.name, logger.experiment.version)
|
||||
tags_path = os.path.join(tags_path, 'meta_tags.csv')
|
||||
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
|
||||
tags_csv=tags_path, on_gpu=False)
|
||||
tags_csv=tags_path)
|
||||
model_2.eval()
|
||||
|
||||
# make prediction
|
||||
|
@ -1448,8 +1448,8 @@ def load_model(exp, save_dir, on_gpu, module_class=LightningTemplateModel):
|
|||
weights_dir = os.path.join(save_dir, checkpoints[0])
|
||||
|
||||
trained_model = module_class.load_from_metrics(weights_path=weights_dir,
|
||||
tags_csv=tags_path,
|
||||
on_gpu=on_gpu)
|
||||
tags_csv=tags_path
|
||||
)
|
||||
|
||||
assert trained_model is not None, 'loading model failed'
|
||||
|
||||
|
|
Loading…
Reference in New Issue