From 05cea3ff8bd1210806c55397698ae39017fc3bc8 Mon Sep 17 00:00:00 2001
From: Nic Eggert
Date: Wed, 23 Oct 2019 03:48:24 -0500
Subject: [PATCH] Save / Load Hyperparameters with checkpoint (#415)

* Save and load hparams from checkpoints

* Update docs

* Add warning when not saving hparams

* Missing import

* Update .run_local_tests.sh

* Update lm_test_module_mixins.py

* Update lightning_module_template.py
---
 .run_local_tests.sh                           |  1 +
 docs/LightningModule/methods.md               | 21 +++++++++-
 .../lightning_module_template.py              |  2 +-
 pytorch_lightning/root_module/root_module.py  | 31 ++++++++++++++
 .../testing/lm_test_module_mixins.py          |  4 +-
 pytorch_lightning/trainer/trainer_io.py       | 10 ++++-
 tests/test_models.py                          | 41 +++++++++++++++++++
 7 files changed, 104 insertions(+), 6 deletions(-)

diff --git a/.run_local_tests.sh b/.run_local_tests.sh
index 1f61fb3be5..57d40c26b0 100644
--- a/.run_local_tests.sh
+++ b/.run_local_tests.sh
@@ -3,5 +3,6 @@ rm -rf _ckpt_*
 rm -rf tests/save_dir*
 rm -rf tests/mlruns_*
 rm -rf tests/tests/*
+rm -rf lightning_logs
 coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules
 coverage report -m
diff --git a/docs/LightningModule/methods.md b/docs/LightningModule/methods.md
index cb96ea7a1d..a4325938a9 100644
--- a/docs/LightningModule/methods.md
+++ b/docs/LightningModule/methods.md
@@ -10,8 +10,25 @@ model.freeze()
 
 ---
 ### load_from_metrics
-This is the easiest/fastest way which uses the meta_tags.csv file from test-tube to rebuild the model.
-The meta_tags.csv file can be found in the test-tube experiment save_dir.
+This is the easiest/fastest way to restore a model: `load_from_checkpoint` loads both the
+hyperparameters and the weights from a checkpoint, such as the one saved by the `ModelCheckpoint` callback.
+
+```{.python}
+pretrained_model = MyLightningModule.load_from_checkpoint(
+    checkpoint_path='/path/to/pytorch_checkpoint.ckpt'
+)
+
+# predict
+pretrained_model.eval()
+pretrained_model.freeze()
+y_hat = pretrained_model(x)
+```
+
+---
+### load_from_metrics
+If you're using test-tube, there is an alternative method which uses the meta_tags.csv
+file from test-tube to rebuild the model. The meta_tags.csv file can be found in the
+test-tube experiment save_dir.
 
 ```{.python}
 pretrained_model = MyLightningModule.load_from_metrics(
diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py
index f773fc81a3..6b58ab46ef 100644
--- a/pl_examples/basic_examples/lightning_module_template.py
+++ b/pl_examples/basic_examples/lightning_module_template.py
@@ -158,7 +158,7 @@ class LightningTemplateModel(LightningModule):
             val_loss = output['val_loss']
 
             # reduce manually when using dp
-            if self.trainer.use_dp:
+            if self.trainer.use_dp or self.trainer.use_ddp2:
                 val_loss = torch.mean(val_loss)
             val_loss_mean += val_loss
 
diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py
index a4f4fb6003..d037b038ab 100644
--- a/pytorch_lightning/root_module/root_module.py
+++ b/pytorch_lightning/root_module/root_module.py
@@ -1,4 +1,5 @@
 import warnings
+from argparse import Namespace
 
 import torch
 
@@ -177,6 +178,36 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
 
         return model
 
+    @classmethod
+    def load_from_checkpoint(cls, checkpoint_path):
+        """
+        Primary way of loading a model from a checkpoint.
+        The checkpoint is loaded onto the CPU; it is up to the user to move the model back to GPU.
+        :param checkpoint_path: path to the checkpoint file
+        :return: LightningModule with the weights and hyperparameters restored from the checkpoint
+        """
+
+        # load on CPU only to avoid OOM issues
+        # then it's up to the user to put it back on GPUs
+        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
+        try:
+            ckpt_hparams = checkpoint['hparams']
+        except KeyError:
+            raise IOError(
+                "Checkpoint does not contain hyperparameters. Are your model hyperparameters stored "
+                "in self.hparams?"
+            )
+        hparams = Namespace(**ckpt_hparams)
+
+        # load the state_dict on the model automatically
+        model = cls(hparams)
+        model.load_state_dict(checkpoint['state_dict'])
+
+        # give the model a chance to load something
+        model.on_load_checkpoint(checkpoint)
+
+        return model
+
     def summarize(self, mode):
         model_summary = ModelSummary(self, mode=mode)
         print(model_summary)
diff --git a/pytorch_lightning/testing/lm_test_module_mixins.py b/pytorch_lightning/testing/lm_test_module_mixins.py
index 3815e84e28..b568676685 100644
--- a/pytorch_lightning/testing/lm_test_module_mixins.py
+++ b/pytorch_lightning/testing/lm_test_module_mixins.py
@@ -80,13 +80,13 @@ class LightningValidationMixin(LightningValidationStepMixin):
             val_loss = output['val_loss']
 
             # reduce manually when using dp
-            if self.trainer.use_dp:
+            if self.trainer.use_dp or self.trainer.use_ddp2:
                 val_loss = torch.mean(val_loss)
             val_loss_mean += val_loss
 
             # reduce manually when using dp
             val_acc = output['val_acc']
-            if self.trainer.use_dp:
+            if self.trainer.use_dp or self.trainer.use_ddp2:
                 val_acc = torch.mean(val_acc)
             val_acc_mean += val_acc
 
diff --git a/pytorch_lightning/trainer/trainer_io.py b/pytorch_lightning/trainer/trainer_io.py
index c19dfedcef..aa40e21de1 100644
--- a/pytorch_lightning/trainer/trainer_io.py
+++ b/pytorch_lightning/trainer/trainer_io.py
@@ -1,6 +1,7 @@
 import os
 import re
 import signal
+import warnings
 from subprocess import call
 
 import torch
@@ -172,9 +173,16 @@ class TrainerIOMixin(object):
 
         checkpoint['lr_schedulers'] = lr_schedulers
 
-        # add the state_dict from the model
+        # add the hparams and state_dict from the model
         model = self.get_model()
         checkpoint['state_dict'] = model.state_dict()
+        if hasattr(model, "hparams"):
+            checkpoint['hparams'] = vars(model.hparams)
+        else:
+            warnings.warn(
+                "Did not find hyperparameters at model.hparams. Saving checkpoint without"
+                " hyperparameters."
+            )
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
diff --git a/tests/test_models.py b/tests/test_models.py
index ef99b96822..612668131e 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -402,6 +402,47 @@ def test_running_test_pretrained_model():
     clear_save_dir()
 
 
+def test_load_model_from_checkpoint():
+    reset_seed()
+
+    """Verify test() on a model loaded from a checkpoint"""
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+
+    save_dir = init_save_dir()
+
+    trainer_options = dict(
+        show_progress_bar=False,
+        max_nb_epochs=1,
+        train_percent_check=0.4,
+        val_percent_check=0.2,
+        checkpoint_callback=True,
+        logger=False,
+        default_save_path=save_dir
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    result = trainer.fit(model)
+
+    # correct result and ok accuracy
+    assert result == 1, 'training failed to complete'
+    pretrained_model = LightningTestModel.load_from_checkpoint(
+        os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt")
+    )
+
+    # test that hparams loaded correctly
+    for k, v in vars(hparams).items():
+        assert getattr(pretrained_model.hparams, k) == v
+
+    new_trainer = Trainer(**trainer_options)
+    new_trainer.test(pretrained_model)
+
+    # test we have good test accuracy
+    assert_ok_test_acc(new_trainer)
+    clear_save_dir()
+
+
 def test_running_test_pretrained_model_dp():
     reset_seed()
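
For context, the core of this change is a round-trip of `model.hparams` through the checkpoint dict:
the trainer-side saving code stores `vars(model.hparams)` under the `'hparams'` key next to the
`state_dict`, and `load_from_checkpoint` rebuilds a `Namespace` from that dict, constructs the model
with it, and then restores the weights. The sketch below reproduces that round-trip with plain
PyTorch; `SimpleModel` and the `example.ckpt` filename are illustrative stand-ins only, not part of
the patch or of the Lightning API.

```{.python}
# Minimal sketch of the hparams round-trip this patch implements.
# SimpleModel and 'example.ckpt' are made up for illustration; the real logic
# lives in LightningModule.load_from_checkpoint and the trainer checkpoint I/O.
from argparse import Namespace

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.layer = nn.Linear(hparams.in_features, hparams.out_features)

    def forward(self, x):
        return self.layer(x)


hparams = Namespace(in_features=28 * 28, out_features=10)
model = SimpleModel(hparams)

# saving side: store the hyperparameters next to the weights in the checkpoint dict
checkpoint = {
    'state_dict': model.state_dict(),
    'hparams': vars(model.hparams),
}
torch.save(checkpoint, 'example.ckpt')

# loading side: load on CPU, rebuild the Namespace, construct the model, restore the weights
ckpt = torch.load('example.ckpt', map_location=lambda storage, loc: storage)
restored = SimpleModel(Namespace(**ckpt['hparams']))
restored.load_state_dict(ckpt['state_dict'])
restored.eval()
```

This also shows why the new warning in trainer_io.py matters: if the model never sets
`self.hparams`, nothing is written under `'hparams'`, and `load_from_checkpoint` raises the
`IOError` above instead of silently rebuilding the model with different hyperparameters.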