Save and load hparams from checkpoints

NicEggert 2019-10-22 15:48:25 -05:00
parent e7c12d936e
commit 46e549c604
3 changed files with 74 additions and 1 deletion


@@ -1,4 +1,5 @@
import warnings
from argparse import Namespace
import torch
@@ -177,6 +178,36 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
        return model

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path):
        """
        Primary way of loading a model from a checkpoint.
        :param checkpoint_path: path to the checkpoint file
        :return: model with hyperparameters and weights restored from the checkpoint
        """
        # load on CPU only to avoid OOM issues
        # then it's up to the user to put it back on GPUs
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)

        try:
            ckpt_hparams = checkpoint['hparams']
        except KeyError:
            raise IOError(
                "Checkpoint does not contain hyperparameters. Are your model hyperparameters stored "
                "in self.hparams?"
            )
        hparams = Namespace(**ckpt_hparams)

        # load the state_dict on the model automatically
        model = cls(hparams)
        model.load_state_dict(checkpoint['state_dict'])

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)

        return model

    def summarize(self, mode):
        model_summary = ModelSummary(self, mode=mode)
        print(model_summary)
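
With this method in place, a trained model can be restored straight from a checkpoint file. A minimal usage sketch follows; CoolModel and the checkpoint path are hypothetical placeholders, not part of this commit:

    from my_project import CoolModel  # hypothetical LightningModule subclass

    # hparams are rebuilt from checkpoint['hparams'], weights from checkpoint['state_dict']
    model = CoolModel.load_from_checkpoint('/path/to/_ckpt_epoch_1.ckpt')
    model.eval()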


@@ -172,9 +172,10 @@ class TrainerIOMixin(object):
        checkpoint['lr_schedulers'] = lr_schedulers

        # add the hparams and state_dict from the model
        model = self.get_model()
        checkpoint['state_dict'] = model.state_dict()
        checkpoint['hparams'] = vars(model.hparams)

        # give the model a chance to add a few things
        model.on_save_checkpoint(checkpoint)
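
The round trip here is intentionally simple: the save path stores vars(model.hparams) as a plain dict so torch.save can serialize it, and load_from_checkpoint rebuilds an argparse Namespace from that dict. A small self-contained sketch with made-up values, only to illustrate the dict/Namespace conversion:

    from argparse import Namespace

    hparams = Namespace(learning_rate=0.02, batch_size=32)  # example values, not from the repo
    stored = vars(hparams)            # what gets written to checkpoint['hparams']
    restored = Namespace(**stored)    # what load_from_checkpoint reconstructs
    assert restored == hparams        # Namespace compares by attributes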


@@ -402,6 +402,47 @@ def test_running_test_pretrained_model():
    clear_save_dir()

def test_load_model_from_checkpoint():
    reset_seed()

    """Verify that a model can be restored from a checkpoint file"""
    hparams = get_hparams()
    model = LightningTestModel(hparams)

    save_dir = init_save_dir()

    trainer_options = dict(
        show_progress_bar=False,
        max_nb_epochs=1,
        train_percent_check=0.4,
        val_percent_check=0.2,
        checkpoint_callback=True,
        logger=False,
        default_save_path=save_dir
    )

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'training failed to complete'

    pretrained_model = LightningTestModel.load_from_checkpoint(
        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
    )

    # test that hparams loaded correctly
    for k, v in vars(hparams).items():
        assert getattr(pretrained_model.hparams, k) == v

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)

    # test we have good test accuracy
    assert_ok_test_acc(new_trainer)

    clear_save_dir()

def test_running_test_pretrained_model_dp():
    reset_seed()