From 46e549c6041ef8fbf7c029e5c116db4c92969475 Mon Sep 17 00:00:00 2001
From: NicEggert
Date: Tue, 22 Oct 2019 15:48:25 -0500
Subject: [PATCH] Save and load hparams from checkpoints

---
 pytorch_lightning/root_module/root_module.py | 31 +++++++++++++++
 pytorch_lightning/trainer/trainer_io.py      |  3 +-
 tests/test_models.py                         | 41 ++++++++++++++++++++
 3 files changed, 74 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index a4f4fb6003..d037b038ab 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -1,4 +1,5 @@
 import warnings
+from argparse import Namespace
 
 import torch
 
@@ -177,6 +178,36 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
 
         return model
 
+    @classmethod
+    def load_from_checkpoint(cls, checkpoint_path):
+        """
+        Primary way of loading a model from a checkpoint
+        :param checkpoint_path: path to the checkpoint file
+        :param map_location: dict for mapping storage {'cuda:1': 'cuda:0'}
+        :return: LightningModule with restored weights and hyperparameters
+        """
+
+        # load on CPU only to avoid OOM issues
+        # then it's up to the user to put it back on GPUs
+        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        try:
+            ckpt_hparams = checkpoint['hparams']
+        except KeyError:
+            raise IOError(
+                "Checkpoint does not contain hyperparameters. Are your model hyperparameters stored "
+                "in self.hparams?"
+            )
+        hparams = Namespace(**ckpt_hparams)
+
+        # load the state_dict on the model automatically
+        model = cls(hparams)
+        model.load_state_dict(checkpoint['state_dict'])
+
+        # give model a chance to load something
+        model.on_load_checkpoint(checkpoint)
+
+        return model
+
     def summarize(self, mode):
         model_summary = ModelSummary(self, mode=mode)
         print(model_summary)
diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index c19dfedcef..30afeddc09 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -172,9 +172,10 @@ class TrainerIOMixin(object):
 
         checkpoint['lr_schedulers'] = lr_schedulers
 
-        # add the state_dict from the model
+        # add the hparams and state_dict from the model
         model = self.get_model()
         checkpoint['state_dict'] = model.state_dict()
+        checkpoint['hparams'] = vars(model.hparams)
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
diff --git a/tests/test_models.py b/tests/test_models.py
index ef99b96822..612668131e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -402,6 +402,47 @@ def test_running_test_pretrained_model():
     clear_save_dir()
 
 
+def test_load_model_from_checkpoint():
+    """Verify a model can be restored, with its hparams, from a checkpoint"""
+    reset_seed()
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_nb_epochs=1,
+        train_percent_check=0.4,
+        val_percent_check=0.2,
+        checkpoint_callback=True,
+        logger=False,
+        default_save_path=save_dir
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # correct result and ok accuracy
+    assert result == 1, 'training failed to complete'
+    pretrained_model = LightningTestModel.load_from_checkpoint(
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
+    )
+
+    # test that hparams loaded correctly
+    for k, v in vars(hparams).items():
+        assert getattr(pretrained_model.hparams, k) == v
+
+    new_trainer = Trainer(**trainer_options)
+    new_trainer.test(pretrained_model)
+
+    # test we have good test accuracy
+    assert_ok_test_acc(new_trainer)
+    clear_save_dir()
+
+
 def test_running_test_pretrained_model_dp():
     reset_seed()
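
A minimal usage sketch of the round trip this patch enables: the trainer now writes vars(model.hparams) under the 'hparams' key of every checkpoint, and load_from_checkpoint rebuilds the Namespace and calls cls(hparams) before loading the state dict. MyModel, its in_dim/out_dim/learning_rate fields, and the checkpoint path below are illustrative placeholders; the training_step/dataloader plumbing and the actual fit() call are elided, and a top-level pytorch_lightning export of LightningModule is assumed.

    from argparse import Namespace

    import torch.nn as nn
    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        # Placeholder module: the constructor takes a single Namespace because
        # load_from_checkpoint re-instantiates the class via cls(hparams).
        def __init__(self, hparams):
            super().__init__()
            # Keeping the hyperparameters on self.hparams is what lets the
            # trainer serialize them as checkpoint['hparams'].
            self.hparams = hparams
            self.layer = nn.Linear(hparams.in_dim, hparams.out_dim)

        def forward(self, x):
            return self.layer(x)


    hparams = Namespace(in_dim=28 * 28, out_dim=10, learning_rate=1e-3)
    model = MyModel(hparams)

    # ... fit with Trainer(checkpoint_callback=True, ...) so a .ckpt file is written ...

    # Restore weights and hyperparameters from the checkpoint alone; the caller
    # no longer has to rebuild the Namespace by hand before constructing the model.
    restored = MyModel.load_from_checkpoint('path/to/_ckpt_epoch_1.ckpt')
    assert vars(restored.hparams) == vars(hparams)

Because load_from_checkpoint always calls cls(hparams), a module that uses this path has to accept a single hparams argument in its constructor and keep it on self.hparams so the trainer can serialize it; modules built around individual keyword arguments would need a thin wrapper.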