From a5d1176cf6ef9e637144f980da11bbe63290c994 Mon Sep 17 00:00:00 2001 From: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Date: Fri, 28 Aug 2020 10:50:52 -0400 Subject: [PATCH] callback method for on_save_checkpoint (#2501) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial draft * fix test * Update pytorch_lightning/trainer/callback_hook.py Co-authored-by: Adrian Wälchli * fix tests * remove old code * untested upgrade script * document limitations * clean up and add tests * Update pytorch_lightning/trainer/training_io.py Co-authored-by: Adrian Wälchli * reflect PR comments * fix formatting * Update docs/source/callbacks.rst * clarify docs * revert change for loading checkpoints * small edits Co-authored-by: Adrian Wälchli --- docs/source/callbacks.rst | 12 ++++ pytorch_lightning/callbacks/base.py | 9 +++ pytorch_lightning/callbacks/early_stopping.py | 15 ++--- .../callbacks/model_checkpoint.py | 10 ++++ pytorch_lightning/trainer/callback_hook.py | 20 +++++++ pytorch_lightning/trainer/training_io.py | 59 ++++++------------- .../utilities/upgrade_checkpoint.py | 45 ++++++++++++++ tests/callbacks/test_early_stopping.py | 50 ++++++++-------- tests/callbacks/test_model_checkpoint.py | 30 ++++++---- tests/utilities/test_upgrade_checkpoint.py | 36 +++++++++++ 10 files changed, 199 insertions(+), 87 deletions(-) create mode 100644 pytorch_lightning/utilities/upgrade_checkpoint.py create mode 100644 tests/utilities/test_upgrade_checkpoint.py diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index c8bb615673..c3a85887b8 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -144,6 +144,18 @@ Lightning has a few built-in callbacks. ---------- +Persisting State +---------------- + +Some callbacks require internal state in order to function properly. You can optionally +choose to persist your callback's state as part of model checkpoint files using the callback hooks +:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`. +However, you must follow two constraints: + +1. Your returned state must be able to be pickled. +2. You can only use one instance of that class in the Trainer callbacks list. We don't support persisting state for multiple callbacks of the same class. + + Best Practices -------------- The following are best practices when using/designing callbacks. diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 6a99e341be..1011e047b9 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -155,3 +155,12 @@ class Callback(abc.ABC): def on_keyboard_interrupt(self, trainer, pl_module): """Called when the training is interrupted by KeyboardInterrupt.""" + pass + + def on_save_checkpoint(self, trainer, pl_module): + """Called when saving a model checkpoint, use to persist state.""" + pass + + def on_load_checkpoint(self, checkpointed_state): + """Called when loading a model checkpoint, use to reload state.""" + pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ff922ad6c7..5b3fc1c9db 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -19,8 +19,6 @@ Early Stopping Monitor a validation metric and stop training when it stops improving. 
""" -from copy import deepcopy - import numpy as np import torch import torch.distributed as dist @@ -126,7 +124,7 @@ class EarlyStopping(Callback): def monitor_op(self): return self.mode_dict[self.mode] - def state_dict(self): + def on_save_checkpoint(self, trainer, pl_module): return { 'wait_count': self.wait_count, 'stopped_epoch': self.stopped_epoch, @@ -134,12 +132,11 @@ class EarlyStopping(Callback): 'patience': self.patience } - def load_state_dict(self, state_dict): - state_dict = deepcopy(state_dict) - self.wait_count = state_dict['wait_count'] - self.stopped_epoch = state_dict['stopped_epoch'] - self.best_score = state_dict['best_score'] - self.patience = state_dict['patience'] + def on_load_checkpoint(self, checkpointed_state): + self.wait_count = checkpointed_state['wait_count'] + self.stopped_epoch = checkpointed_state['stopped_epoch'] + self.best_score = checkpointed_state['best_score'] + self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): if trainer.running_sanity_check: diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f19e46f108..63dddd1aa1 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -426,3 +426,13 @@ class ModelCheckpoint(Callback): for cur_path in del_list: if cur_path != filepath: self._del_model(cur_path) + + def on_save_checkpoint(self, trainer, pl_module): + return { + 'best_model_score': self.best_model_score, + 'best_model_path': self.best_model_path, + } + + def on_load_checkpoint(self, checkpointed_state): + self.best_model_score = checkpointed_state['best_model_score'] + self.best_model_path = checkpointed_state['best_model_path'] \ No newline at end of file diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 421f468fec..27539dbd1f 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -13,6 +13,7 @@ # limitations under the License. 
from abc import ABC +from copy import deepcopy from typing import Callable, List from pytorch_lightning.callbacks import Callback @@ -189,3 +190,22 @@ class TrainerCallbackHookMixin(ABC): """Called when the training is interrupted by KeyboardInterrupt.""" for callback in self.callbacks: callback.on_keyboard_interrupt(self, self.get_model()) + + def on_save_checkpoint(self): + """Called when saving a model checkpoint.""" + callback_states = {} + for callback in self.callbacks: + callback_class = type(callback) + state = callback.on_save_checkpoint(self, self.get_model()) + if state: + callback_states[callback_class] = state + return callback_states + + def on_load_checkpoint(self, checkpoint): + """Called when loading a model checkpoint.""" + callback_states = checkpoint.get('callbacks') + for callback in self.callbacks: + state = callback_states.get(type(callback)) + if state: + state = deepcopy(state) + callback.on_load_checkpoint(state) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index fb8387a51c..f10eb96be1 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -109,16 +109,15 @@ import torch.distributed as torch_distrib import pytorch_lightning from pytorch_lightning import _logger as log -from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.overrides.data_parallel import ( - LightningDistributedDataParallel, - LightningDataParallel, -) -from pytorch_lightning.utilities import rank_zero_warn, AMPType -from pytorch_lightning.utilities.cloud_io import gfile, makedirs -from pytorch_lightning.utilities.cloud_io import load as pl_load, atomic_save +from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel +from pytorch_lightning.utilities import AMPType, rank_zero_warn +from pytorch_lightning.utilities.cloud_io import atomic_save, gfile +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.cloud_io import makedirs +from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS try: import torch_xla @@ -340,20 +339,9 @@ class TrainerIOMixin(ABC): if not weights_only: - # TODO support more generic way for callbacks to persist a state_dict in a checkpoint - checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] - early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)] - - if checkpoint_callbacks: - # we add the official checkpoint callback to the end of the list - # extra user provided callbacks will not be persisted yet - checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] = self.checkpoint_callback.best_model_score - checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] = self.checkpoint_callback.best_model_path - - if early_stopping_callbacks and checkpoint_callbacks: - # we add the official early stopping callback to the end of the list - # extra user provided callbacks will not be persisted yet - checkpoint['early_stop_callback_state_dict'] = early_stopping_callbacks[-1].state_dict() + # save callbacks + callback_states = self.on_save_checkpoint() + checkpoint['callbacks'] = callback_states # save optimizers optimizer_states = [] @@ -424,25 +412,16 @@ class 
TrainerIOMixin(ABC): ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.' ) - # TODO support more generic way for callbacks to load callback state_dicts - checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] - early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)] + if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]): + raise ValueError( + "The checkpoint you're attempting to load follows an" + " outdated schema. You can upgrade to the current schema by running" + " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" + " where `model.ckpt` is your checkpoint file." + ) - if checkpoint_callbacks: - if ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE in checkpoint: - checkpoint_callbacks[-1].best_model_score = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] - else: - # Old naming until version 0.7.6 - rank_zero_warn( - 'Loading a checkpoint created with an old version of Lightning; ' - 'this will not be supported in the future.' - ) - checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best'] - checkpoint_callbacks[-1].best_model_path = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] - - if early_stopping_callbacks: - state = checkpoint['early_stop_callback_state_dict'] - early_stopping_callbacks[-1].load_state_dict(state) + # load callback states + self.on_load_checkpoint(checkpoint) self.global_step = checkpoint['global_step'] self.current_epoch = checkpoint['epoch'] diff --git a/pytorch_lightning/utilities/upgrade_checkpoint.py b/pytorch_lightning/utilities/upgrade_checkpoint.py new file mode 100644 index 0000000000..c071dc97c2 --- /dev/null +++ b/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -0,0 +1,45 @@ +import argparse +from shutil import copyfile + +import torch + +from pytorch_lightning import _logger as log +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint + +KEYS_MAPPING = { + "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), + "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), + "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), + "early_stop_callback_wait": (EarlyStopping, "wait_count"), + "early_stop_callback_patience": (EarlyStopping, "patience"), +} + + +def upgrade_checkpoint(filepath): + checkpoint = torch.load(filepath) + checkpoint["callbacks"] = checkpoint.get("callbacks") or {} + + for key, new_path in KEYS_MAPPING.items(): + if key in checkpoint: + value = checkpoint[key] + callback_type, callback_key = new_path + checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} + checkpoint["callbacks"][callback_type][callback_key] = value + del checkpoint[key] + + torch.save(checkpoint, filepath) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Upgrade an old checkpoint to the current schema. \ + This will also save a backup of the original file." 
+ ) + parser.add_argument("--file", help="filepath for a checkpoint to upgrade") + + args = parser.parse_args() + + log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") + copyfile(args.file, args.file + ".bak") + upgrade_checkpoint(args.file) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index dedd1ef889..028f4f32ea 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -11,6 +11,23 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from tests.base import EvalModelTemplate +class EarlyStoppingTestRestore(EarlyStopping): + # this class has to be defined outside the test function, otherwise we get pickle error + def __init__(self, expected_state=None): + super().__init__() + self.expected_state = expected_state + # cache the state for each epoch + self.saved_states = [] + + def on_train_start(self, trainer, pl_module): + if self.expected_state: + assert self.on_save_checkpoint(trainer, pl_module) == self.expected_state + + def on_validation_end(self, trainer, pl_module): + super().on_validation_end(trainer, pl_module) + self.saved_states.append(self.on_save_checkpoint(trainer, pl_module).copy()) + + def test_resume_early_stopping_from_checkpoint(tmpdir): """ Prevent regressions to bugs: @@ -18,27 +35,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): https://github.com/PyTorchLightning/pytorch-lightning/issues/1463 """ - class EarlyStoppingTestStore(EarlyStopping): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # cache the state for each epoch - self.saved_states = [] - - def on_validation_end(self, trainer, pl_module): - super().on_validation_end(trainer, pl_module) - self.saved_states.append(deepcopy(self.state_dict())) - - class EarlyStoppingTestRestore(EarlyStopping): - def __init__(self, expected_state): - super().__init__() - self.expected_state = expected_state - - def on_train_start(self, trainer, pl_module): - assert self.state_dict() == self.expected_state - model = EvalModelTemplate() checkpoint_callback = ModelCheckpoint(save_top_k=1) - early_stop_callback = EarlyStoppingTestStore() + early_stop_callback = EarlyStoppingTestRestore() trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, @@ -52,9 +51,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): # ensure state is persisted properly checkpoint = torch.load(checkpoint_filepath) # the checkpoint saves "epoch + 1" - early_stop_callback_state = early_stop_callback.saved_states[checkpoint['epoch'] - 1] + early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] assert 4 == len(early_stop_callback.saved_states) - assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state + assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state # ensure state is reloaded properly (assertion in the callback) early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state) @@ -84,11 +83,10 @@ def test_early_stopping_no_extraneous_invocations(tmpdir): assert len(trainer.dev_debugger.early_stopping_history) == expected_count -@pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [ - ([6, 5, 5, 5, 5, 5], 3, 4), - ([6, 5, 4, 4, 3, 3], 1, 3), - ([6, 5, 6, 5, 5, 5], 3, 4), -]) +@pytest.mark.parametrize( + "loss_values, patience, expected_stop_epoch", + [([6, 5, 5, 5, 5, 5], 3, 4), ([6, 5, 4, 4, 3, 3], 1, 3), 
([6, 5, 6, 5, 5, 5], 3, 4),], +) def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch): """Test to ensure that early stopping is not triggered before patience is exhausted.""" diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 976fc887d3..3e07902833 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -15,7 +15,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from tests.base import EvalModelTemplate -@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) +@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ tutils.reset_seed() @@ -25,11 +25,11 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2) trainer.fit(model) - assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints' + assert checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints" @pytest.mark.parametrize( - 'logger_version,expected', [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], + "logger_version,expected", [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")], ) def test_model_checkpoint_path(tmpdir, logger_version, expected): """Test that "version_" prefix is only added when logger's version is an integer""" @@ -86,7 +86,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): num_epochs = 4 model_checkpoint = ModelCheckpointTestInvocations(expected_count=num_epochs, save_top_k=-1) trainer = Trainer( - distributed_backend='ddp_cpu', + distributed_backend="ddp_cpu", num_processes=2, default_root_dir=tmpdir, early_stop_callback=False, @@ -104,10 +104,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): num_epochs = 3 model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True) trainer = Trainer( - default_root_dir=tmpdir, - early_stop_callback=False, - checkpoint_callback=model_checkpoint, - max_epochs=num_epochs, + default_root_dir=tmpdir, early_stop_callback=False, checkpoint_callback=model_checkpoint, max_epochs=num_epochs, ) trainer.fit(model) path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {}) # epoch=3.ckpt @@ -115,15 +112,24 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert path_last_epoch != path_last ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) - matching_keys = ( + + trainer_keys = ( "epoch", "global_step", - ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE, - ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH, ) - for key in matching_keys: + for key in trainer_keys: assert ckpt_last_epoch[key] == ckpt_last[key] + checkpoint_callback_keys = ( + "best_model_score", + "best_model_path", ) + for key in checkpoint_callback_keys: + assert ( + ckpt_last_epoch["callbacks"][type(model_checkpoint)][key] + == ckpt_last["callbacks"][type(model_checkpoint)][key] + ) + # it is easier to load the model objects than to iterate over the raw dict of tensors model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch) model_last = EvalModelTemplate.load_from_checkpoint(path_last) diff --git a/tests/utilities/test_upgrade_checkpoint.py new 
file mode 100644 index 0000000000..5f5ecd6b5f --- /dev/null +++ b/tests/utilities/test_upgrade_checkpoint.py @@ -0,0 +1,36 @@ +import pytest +import os + +import torch + +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint + + +@pytest.mark.parametrize( + "old_checkpoint, new_checkpoint", + [ + ( + {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}}, + ), + ( + {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}}, + ), + ( + {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": 'path'}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": 'path'}}}, + ), + ( + {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, + {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}}, + ), + ], +) +def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): + filepath = os.path.join(tmpdir, "model.ckpt") + torch.save(old_checkpoint, filepath) + upgrade_checkpoint(filepath) + updated_checkpoint = torch.load(filepath) + assert updated_checkpoint == new_checkpoint
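
Usage sketch (illustrative only, not part of the diff above): the snippet below shows how a user-defined callback can persist its own state through the new hooks. The hook signatures, the keying of saved state by callback class under checkpoint['callbacks'], and the `callbacks` argument of `Trainer` come from the changes in this patch; the `ValidationRunCounter` class and its `val_runs` attribute are hypothetical names invented for the example. The returned state must be picklable, and only one instance of the class may appear in the callbacks list.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class ValidationRunCounter(Callback):
    """Hypothetical callback that counts finished validation loops and survives checkpointing."""

    def __init__(self):
        self.val_runs = 0

    def on_validation_end(self, trainer, pl_module):
        self.val_runs += 1

    def on_save_checkpoint(self, trainer, pl_module):
        # return a picklable dict; the trainer stores it under checkpoint['callbacks'][type(self)]
        return {'val_runs': self.val_runs}

    def on_load_checkpoint(self, checkpointed_state):
        # receives a copy of the dict returned above when a checkpoint is loaded
        self.val_runs = checkpointed_state['val_runs']


# only one ValidationRunCounter instance is allowed here, since saved state is keyed by the callback's class
trainer = Trainer(callbacks=[ValidationRunCounter()])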