Save and load hparams from checkpoints
parent e7c12d936e
commit 46e549c604
@@ -1,4 +1,5 @@
 import warnings
+from argparse import Namespace
 
 import torch
 
@@ -177,6 +178,36 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
 
         return model
 
+    @classmethod
+    def load_from_checkpoint(cls, checkpoint_path):
+        """
+        Primary way of loading a model from a checkpoint.
+        :param checkpoint_path: path to the checkpoint file;
+            the checkpoint is loaded onto CPU (all storages are mapped to CPU)
+        :return: model with the saved hyperparameters and state_dict restored
+        """
+
+        # load on CPU only to avoid OOM issues
+        # then it's up to the user to put the model back on GPUs
+        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        try:
+            ckpt_hparams = checkpoint['hparams']
+        except KeyError:
+            raise IOError(
+                "Checkpoint does not contain hyperparameters. Are your model hyperparameters stored "
+                "in self.hparams?"
+            )
+        hparams = Namespace(**ckpt_hparams)
+
+        # load the state_dict on the model automatically
+        model = cls(hparams)
+        model.load_state_dict(checkpoint['state_dict'])
+
+        # give the model a chance to load something from the checkpoint
+        model.on_load_checkpoint(checkpoint)
+
+        return model
+
     def summarize(self, mode):
         model_summary = ModelSummary(self, mode=mode)
         print(model_summary)
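The new classmethod assumes the model's hyperparameters were stored as an argparse Namespace: the trainer serialises them with vars() (see the trainer hunk below) and load_from_checkpoint rebuilds the Namespace before instantiating the model with cls(hparams). A minimal sketch of that round trip, using hypothetical hyperparameter names rather than anything from this commit:

from argparse import Namespace

# hypothetical hyperparameters, standing in for whatever a model keeps on self.hparams
hparams = Namespace(learning_rate=0.02, batch_size=32)

# what the trainer writes into the checkpoint dict
ckpt_hparams = vars(hparams)  # {'learning_rate': 0.02, 'batch_size': 32}

# what load_from_checkpoint reconstructs before calling cls(hparams)
restored = Namespace(**ckpt_hparams)
assert restored == hparams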
@@ -172,9 +172,10 @@ class TrainerIOMixin(object):
 
         checkpoint['lr_schedulers'] = lr_schedulers
 
-        # add the state_dict from the model
+        # add the hparams and state_dict from the model
         model = self.get_model()
         checkpoint['state_dict'] = model.state_dict()
+        checkpoint['hparams'] = vars(model.hparams)
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
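Writing checkpoint['hparams'] = vars(model.hparams) only works if the model actually keeps its constructor argument on self.hparams, which is exactly what the new error message in load_from_checkpoint asks about. A minimal sketch of a module that satisfies that contract; the class name, layer, and hyperparameter fields are illustrative and not part of this commit:

from argparse import Namespace

import torch.nn as nn


class ExampleModel(nn.Module):  # stands in for a LightningModule subclass
    def __init__(self, hparams):
        super().__init__()
        # keep the Namespace on self.hparams so the trainer can store vars(self.hparams)
        # in the checkpoint and load_from_checkpoint can rebuild it later
        self.hparams = hparams
        self.layer = nn.Linear(hparams.in_features, hparams.out_features)


model = ExampleModel(Namespace(in_features=28 * 28, out_features=10))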
@@ -402,6 +402,47 @@ def test_running_test_pretrained_model():
 
     clear_save_dir()
 
 
+def test_load_model_from_checkpoint():
+    """Verify that a model can be restored from a checkpoint and still run test()"""
+    reset_seed()
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_nb_epochs=1,
+        train_percent_check=0.4,
+        val_percent_check=0.2,
+        checkpoint_callback=True,
+        logger=False,
+        default_save_path=save_dir
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # correct result and ok accuracy
+    assert result == 1, 'training failed to complete'
+    pretrained_model = LightningTestModel.load_from_checkpoint(
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
+    )
+
+    # test that hparams loaded correctly
+    for k, v in vars(hparams).items():
+        assert getattr(pretrained_model.hparams, k) == v
+
+    new_trainer = Trainer(**trainer_options)
+    new_trainer.test(pretrained_model)
+
+    # test we have good test accuracy
+    assert_ok_test_acc(new_trainer)
+    clear_save_dir()
+
+
 def test_running_test_pretrained_model_dp():
     reset_seed()
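As the comments in load_from_checkpoint note, the checkpoint is mapped onto CPU to avoid out-of-memory issues, so moving the restored model back onto a GPU is left to the caller. A short sketch of how a caller of the test model above might do that; checkpoint_path is a placeholder for a real checkpoint file:

import torch

# restore weights and hparams on CPU, then move to GPU if one is available
pretrained_model = LightningTestModel.load_from_checkpoint(checkpoint_path)
if torch.cuda.is_available():
    pretrained_model = pretrained_model.cuda()
pretrained_model.eval()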