callback method for on_save_checkpoint (#2501)

* initial draft

* fix test

* Update pytorch_lightning/trainer/callback_hook.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* fix tests

* remove old code

* untested upgrade script

* document limitations

* clean up and add tests

* Update pytorch_lightning/trainer/training_io.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* reflect PR comments

* fix formatting

* Update docs/source/callbacks.rst

* clarify docs

* revert change for loading checkpoints

* small edits

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Jeremy Jordan 2020-08-28 10:50:52 -04:00 committed by GitHub
parent 79375e6d0a
commit a5d1176cf6
10 changed files with 199 additions and 87 deletions

@@ -144,6 +144,18 @@ Lightning has a few built-in callbacks.
----------
Persisting State
----------------
Some callbacks require internal state in order to function properly. You can optionally
persist your callback's state as part of model checkpoint files using the callback hooks
:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
However, you must follow two constraints (see the sketch after this list):
1. The state returned by ``on_save_checkpoint`` must be picklable.
2. Only one instance of a given callback class may appear in the Trainer's callbacks list; persisting state for multiple callbacks of the same class is not supported.
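For illustration only (not part of this change): a minimal stateful callback satisfying both constraints might look like the following sketch, where ``CounterCallback``, its ``processed_batches`` attribute, and the use of the ``on_batch_end`` hook are hypothetical.

from pytorch_lightning.callbacks import Callback

class CounterCallback(Callback):
    """Counts training batches and persists the count inside model checkpoints."""

    def __init__(self):
        self.processed_batches = 0

    def on_batch_end(self, trainer, pl_module):
        self.processed_batches += 1

    def on_save_checkpoint(self, trainer, pl_module):
        # the returned dict must be picklable; it is stored in the checkpoint keyed by this class
        return {'processed_batches': self.processed_batches}

    def on_load_checkpoint(self, checkpointed_state):
        # receives exactly the dict returned by on_save_checkpoint
        self.processed_batches = checkpointed_state['processed_batches']

Because persisted state is keyed by the callback's class, adding a second ``CounterCallback`` to the same Trainer would run into constraint 2.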
Best Practices
--------------
The following are best practices when using/designing callbacks.

@@ -155,3 +155,12 @@ class Callback(abc.ABC):
def on_keyboard_interrupt(self, trainer, pl_module):
"""Called when the training is interrupted by KeyboardInterrupt."""
pass
def on_save_checkpoint(self, trainer, pl_module):
"""Called when saving a model checkpoint; return the state you want to persist."""
pass
def on_load_checkpoint(self, checkpointed_state):
"""Called when loading a model checkpoint; use ``checkpointed_state`` to restore the state."""
pass

@@ -19,8 +19,6 @@ Early Stopping
Monitor a validation metric and stop training when it stops improving.
"""
from copy import deepcopy
import numpy as np
import torch
import torch.distributed as dist
@@ -126,7 +124,7 @@ class EarlyStopping(Callback):
def monitor_op(self):
return self.mode_dict[self.mode]
def state_dict(self):
def on_save_checkpoint(self, trainer, pl_module):
return {
'wait_count': self.wait_count,
'stopped_epoch': self.stopped_epoch,
@@ -134,12 +132,11 @@ class EarlyStopping(Callback):
'patience': self.patience
}
def load_state_dict(self, state_dict):
state_dict = deepcopy(state_dict)
self.wait_count = state_dict['wait_count']
self.stopped_epoch = state_dict['stopped_epoch']
self.best_score = state_dict['best_score']
self.patience = state_dict['patience']
def on_load_checkpoint(self, checkpointed_state):
self.wait_count = checkpointed_state['wait_count']
self.stopped_epoch = checkpointed_state['stopped_epoch']
self.best_score = checkpointed_state['best_score']
self.patience = checkpointed_state['patience']
def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:

@@ -426,3 +426,13 @@ class ModelCheckpoint(Callback):
for cur_path in del_list:
if cur_path != filepath:
self._del_model(cur_path)
def on_save_checkpoint(self, trainer, pl_module):
return {
'best_model_score': self.best_model_score,
'best_model_path': self.best_model_path,
}
def on_load_checkpoint(self, checkpointed_state):
self.best_model_score = checkpointed_state['best_model_score']
self.best_model_path = checkpointed_state['best_model_path']

@@ -13,6 +13,7 @@
# limitations under the License.
from abc import ABC
from copy import deepcopy
from typing import Callable, List
from pytorch_lightning.callbacks import Callback
@@ -189,3 +190,22 @@ class TrainerCallbackHookMixin(ABC):
"""Called when the training is interrupted by KeyboardInterrupt."""
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.get_model())
def on_save_checkpoint(self):
"""Called when saving a model checkpoint."""
callback_states = {}
for callback in self.callbacks:
callback_class = type(callback)
state = callback.on_save_checkpoint(self, self.get_model())
if state:
callback_states[callback_class] = state
return callback_states
def on_load_checkpoint(self, checkpoint):
"""Called when loading a model checkpoint."""
callback_states = checkpoint.get('callbacks')
for callback in self.callbacks:
state = callback_states.get(type(callback))
if state:
state = deepcopy(state)
callback.on_load_checkpoint(state)
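To make the resulting layout concrete, the sketch below shows roughly what a checkpoint dict produced with these hooks could contain; only the ``callbacks`` structure and its keys follow this change, while every value (epoch, step, scores, paths) is illustrative.

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# hypothetical checkpoint contents after on_save_checkpoint has run (illustrative values)
checkpoint = {
    'epoch': 3,
    'global_step': 1200,
    'callbacks': {
        EarlyStopping: {'wait_count': 0, 'stopped_epoch': 0, 'best_score': 0.25, 'patience': 3},
        ModelCheckpoint: {'best_model_score': 0.25, 'best_model_path': 'epoch=2.ckpt'},
    },
    # ... plus model weights, optimizer states, etc.
}

On load, ``on_load_checkpoint`` above looks up each callback by its class in ``checkpoint['callbacks']`` and deep-copies the state before handing it back to the callback.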

@@ -109,16 +109,15 @@ import torch.distributed as torch_distrib
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import load as pl_load, atomic_save
from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, gfile
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import makedirs
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
try:
import torch_xla
@@ -340,20 +339,9 @@ class TrainerIOMixin(ABC):
if not weights_only:
# TODO support more generic way for callbacks to persist a state_dict in a checkpoint
checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]
if checkpoint_callbacks:
# we add the official checkpoint callback to the end of the list
# extra user provided callbacks will not be persisted yet
checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] = self.checkpoint_callback.best_model_score
checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] = self.checkpoint_callback.best_model_path
if early_stopping_callbacks and checkpoint_callbacks:
# we add the official early stopping callback to the end of the list
# extra user provided callbacks will not be persisted yet
checkpoint['early_stop_callback_state_dict'] = early_stopping_callbacks[-1].state_dict()
# save callbacks
callback_states = self.on_save_checkpoint()
checkpoint['callbacks'] = callback_states
# save optimizers
optimizer_states = []
@@ -424,25 +412,16 @@ class TrainerIOMixin(ABC):
' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
)
# TODO support more generic way for callbacks to load callback state_dicts
checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]
if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
raise ValueError(
"The checkpoint you're attempting to load follows an"
" outdated schema. You can upgrade to the current schema by running"
" `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
" where `model.ckpt` is your checkpoint file."
)
if checkpoint_callbacks:
if ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE in checkpoint:
checkpoint_callbacks[-1].best_model_score = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE]
else:
# Old naming until version 0.7.6
rank_zero_warn(
'Loading a checkpoint created with an old version of Lightning; '
'this will not be supported in the future.'
)
checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best']
checkpoint_callbacks[-1].best_model_path = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH]
if early_stopping_callbacks:
state = checkpoint['early_stop_callback_state_dict']
early_stopping_callbacks[-1].load_state_dict(state)
# load callback states
self.on_load_checkpoint(checkpoint)
self.global_step = checkpoint['global_step']
self.current_epoch = checkpoint['epoch']

@@ -0,0 +1,45 @@
import argparse
from shutil import copyfile
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
KEYS_MAPPING = {
"checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
"checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
"checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
"early_stop_callback_wait": (EarlyStopping, "wait_count"),
"early_stop_callback_patience": (EarlyStopping, "patience"),
}
def upgrade_checkpoint(filepath):
checkpoint = torch.load(filepath)
checkpoint["callbacks"] = checkpoint.get("callbacks") or {}
for key, new_path in KEYS_MAPPING.items():
if key in checkpoint:
value = checkpoint[key]
callback_type, callback_key = new_path
checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
checkpoint["callbacks"][callback_type][callback_key] = value
del checkpoint[key]
torch.save(checkpoint, filepath)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Upgrade an old checkpoint to the current schema. \
This will also save a backup of the original file."
)
parser.add_argument("--file", help="filepath for a checkpoint to upgrade")
args = parser.parse_args()
log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
copyfile(args.file, args.file + ".bak")
upgrade_checkpoint(args.file)
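For reference, an old-format checkpoint can be upgraded either from the command line, ``python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt``, or programmatically as sketched below; ``old_model.ckpt`` is a placeholder path.

from shutil import copyfile

from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint

copyfile("old_model.ckpt", "old_model.ckpt.bak")  # keep a backup, mirroring the CLI entry point
upgrade_checkpoint("old_model.ckpt")  # rewrites the file in place using the new 'callbacks' schema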

@@ -11,6 +11,23 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate
class EarlyStoppingTestRestore(EarlyStopping):
# this class has to be defined outside the test function, otherwise we get a pickling error
def __init__(self, expected_state=None):
super().__init__()
self.expected_state = expected_state
# cache the state for each epoch
self.saved_states = []
def on_train_start(self, trainer, pl_module):
if self.expected_state:
assert self.on_save_checkpoint(trainer, pl_module) == self.expected_state
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.saved_states.append(self.on_save_checkpoint(trainer, pl_module).copy())
def test_resume_early_stopping_from_checkpoint(tmpdir):
"""
Prevent regressions to bugs:
@@ -18,27 +35,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
"""
class EarlyStoppingTestStore(EarlyStopping):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# cache the state for each epoch
self.saved_states = []
def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.saved_states.append(deepcopy(self.state_dict()))
class EarlyStoppingTestRestore(EarlyStopping):
def __init__(self, expected_state):
super().__init__()
self.expected_state = expected_state
def on_train_start(self, trainer, pl_module):
assert self.state_dict() == self.expected_state
model = EvalModelTemplate()
checkpoint_callback = ModelCheckpoint(save_top_k=1)
early_stop_callback = EarlyStoppingTestStore()
early_stop_callback = EarlyStoppingTestRestore()
trainer = Trainer(
default_root_dir=tmpdir,
checkpoint_callback=checkpoint_callback,
@@ -52,9 +51,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
# ensure state is persisted properly
checkpoint = torch.load(checkpoint_filepath)
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint['epoch'] - 1]
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
assert 4 == len(early_stop_callback.saved_states)
assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state
assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
@@ -84,11 +83,10 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
assert len(trainer.dev_debugger.early_stopping_history) == expected_count
@pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [
([6, 5, 5, 5, 5, 5], 3, 4),
([6, 5, 4, 4, 3, 3], 1, 3),
([6, 5, 6, 5, 5, 5], 3, 4),
])
@pytest.mark.parametrize(
"loss_values, patience, expected_stop_epoch",
[([6, 5, 5, 5, 5, 5], 3, 4), ([6, 5, 4, 4, 3, 3], 1, 3), ([6, 5, 6, 5, 5, 5], 3, 4),],
)
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
"""Test to ensure that early stopping is not triggered before patience is exhausted."""

@@ -15,7 +15,7 @@ from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate
@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
""" Test that None in checkpoint callback is valid and that chkp_path is set correctly """
tutils.reset_seed()
@@ -25,11 +25,11 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2)
trainer.fit(model)
assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints'
assert checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints"
@pytest.mark.parametrize(
'logger_version,expected', [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
"logger_version,expected", [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
"""Test that "version_" prefix is only added when logger's version is an integer"""
@@ -86,7 +86,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
num_epochs = 4
model_checkpoint = ModelCheckpointTestInvocations(expected_count=num_epochs, save_top_k=-1)
trainer = Trainer(
distributed_backend='ddp_cpu',
distributed_backend="ddp_cpu",
num_processes=2,
default_root_dir=tmpdir,
early_stop_callback=False,
@@ -104,10 +104,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
num_epochs = 3
model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=model_checkpoint,
max_epochs=num_epochs,
default_root_dir=tmpdir, early_stop_callback=False, checkpoint_callback=model_checkpoint, max_epochs=num_epochs,
)
trainer.fit(model)
path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {}) # epoch=3.ckpt
@@ -115,15 +112,24 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
assert path_last_epoch != path_last
ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
matching_keys = (
trainer_keys = (
"epoch",
"global_step",
ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE,
ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH,
)
for key in matching_keys:
for key in trainer_keys:
assert ckpt_last_epoch[key] == ckpt_last[key]
checkpoint_callback_keys = (
"best_model_score",
"best_model_path",
)
for key in checkpoint_callback_keys:
assert (
ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
== ckpt_last["callbacks"][type(model_checkpoint)][key]
)
# it is easier to load the model objects than to iterate over the raw dict of tensors
model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
model_last = EvalModelTemplate.load_from_checkpoint(path_last)

@@ -0,0 +1,36 @@
import pytest
import os
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint
@pytest.mark.parametrize(
"old_checkpoint, new_checkpoint",
[
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
),
(
{"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": 'path'},
{"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": 'path'}}},
),
(
{"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
{"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}},
),
],
)
def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
filepath = os.path.join(tmpdir, "model.ckpt")
torch.save(old_checkpoint, filepath)
upgrade_checkpoint(filepath)
updated_checkpoint = torch.load(filepath)
assert updated_checkpoint == new_checkpoint