Save and load hparams from checkpoints
parent e7c12d936e
commit 46e549c604
@@ -1,4 +1,5 @@
 import warnings
+from argparse import Namespace
 
 import torch
 
@@ -177,6 +178,36 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
 
         return model
 
+    @classmethod
+    def load_from_checkpoint(cls, checkpoint_path):
+        """
+        Primary way of loading a model from a checkpoint.
+
+        :param checkpoint_path: path to the checkpoint file
+        :return: model with its hyperparameters and weights restored
+        """
+
+        # load on CPU only to avoid OOM issues;
+        # it is then up to the user to put the model back on GPUs
+        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        try:
+            ckpt_hparams = checkpoint['hparams']
+        except KeyError:
+            raise IOError(
+                "Checkpoint does not contain hyperparameters. Are your model hyperparameters "
+                "stored in self.hparams?"
+            )
+        hparams = Namespace(**ckpt_hparams)
+
+        # load the state_dict on the model automatically
+        model = cls(hparams)
+        model.load_state_dict(checkpoint['state_dict'])
+
+        # give the model a chance to load something else from the checkpoint
+        model.on_load_checkpoint(checkpoint)
+
+        return model
+
     def summarize(self, mode):
         model_summary = ModelSummary(self, mode=mode)
         print(model_summary)
 
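With this method in place, restoring a trained model is a single call. A minimal usage sketch, assuming a hypothetical LightningModule subclass `LitModel` and an illustrative checkpoint path (both are placeholders, not part of this commit):

    from my_project.models import LitModel  # hypothetical LightningModule subclass

    # weights and hparams are restored in one call; the model is loaded onto CPU,
    # so it is up to the caller to move it to a GPU afterwards
    model = LitModel.load_from_checkpoint("checkpoints/_ckpt_epoch_1.ckpt")
    model.eval()

    # hparams come back as an argparse.Namespace rebuilt from the checkpoint dict
    print(model.hparams)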
@@ -172,9 +172,10 @@ class TrainerIOMixin(object):
 
         checkpoint['lr_schedulers'] = lr_schedulers
 
-        # add the state_dict from the model
+        # add the hparams and state_dict from the model
         model = self.get_model()
         checkpoint['state_dict'] = model.state_dict()
+        checkpoint['hparams'] = vars(model.hparams)
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
 
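The hyperparameters are stored via `vars(model.hparams)`, i.e. as a plain dict rather than a Namespace, so the checkpoint remains a pickle of builtin types; `load_from_checkpoint` then rebuilds the Namespace on the way back in. A minimal, standalone sketch of that round trip (the file path and hparam values below are illustrative only):

    from argparse import Namespace
    import torch

    hparams = Namespace(learning_rate=0.02, batch_size=32)

    # save path: flatten the Namespace to a dict inside the checkpoint
    checkpoint = {'hparams': vars(hparams)}
    torch.save(checkpoint, '/tmp/demo.ckpt')

    # load path: rebuild the Namespace from that dict
    restored = Namespace(**torch.load('/tmp/demo.ckpt')['hparams'])
    assert restored == hparams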
@@ -402,6 +402,47 @@ def test_running_test_pretrained_model():
     clear_save_dir()
 
 
+def test_load_model_from_checkpoint():
+    """Verify that load_from_checkpoint restores both weights and hparams."""
+    reset_seed()
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_nb_epochs=1,
+        train_percent_check=0.4,
+        val_percent_check=0.2,
+        checkpoint_callback=True,
+        logger=False,
+        default_save_path=save_dir
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # correct result and ok accuracy
+    assert result == 1, 'training failed to complete'
+    pretrained_model = LightningTestModel.load_from_checkpoint(
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
+    )
+
+    # test that hparams loaded correctly
+    for k, v in vars(hparams).items():
+        assert getattr(pretrained_model.hparams, k) == v
+
+    new_trainer = Trainer(**trainer_options)
+    new_trainer.test(pretrained_model)
+
+    # test we have good test accuracy
+    assert_ok_test_acc(new_trainer)
+    clear_save_dir()
+
+
 def test_running_test_pretrained_model_dp():
     reset_seed()
 