callback method for on_save_checkpoint (#2501)
* initial draft
* fix test
* Update pytorch_lightning/trainer/callback_hook.py
  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* fix tests
* remove old code
* untested upgrade script
* document limitations
* clean up and add tests
* Update pytorch_lightning/trainer/training_io.py
  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* reflect PR comments
* fix formatting
* Update docs/source/callbacks.rst
* clarify docs
* revert change for loading checkpoints
* small edits

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
parent 79375e6d0a · commit a5d1176cf6
@@ -144,6 +144,18 @@ Lightning has a few built-in callbacks.

----------

Persisting State
----------------

Some callbacks require internal state in order to function properly. You can optionally
persist your callback's state as part of the model checkpoint files using the callback hooks
:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
However, you must follow two constraints (a minimal example follows the list):

1. The state you return must be picklable.
2. You can use only one instance of a given callback class in the Trainer's callbacks list; persisting state for multiple callbacks of the same class is not supported.
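For illustration, here is a minimal sketch of a stateful callback. The ``Counter`` class and its ``batches_seen`` attribute are made up for this example; only the two hooks and the storage layout come from the API shown in this PR.

.. code-block:: python

    from pytorch_lightning.callbacks import Callback

    class Counter(Callback):
        def __init__(self):
            # hypothetical state we want to survive a resume from a checkpoint
            self.batches_seen = 0

        def on_batch_end(self, trainer, pl_module):
            self.batches_seen += 1

        def on_save_checkpoint(self, trainer, pl_module):
            # return a picklable dict; the Trainer stores it in the checkpoint
            # under checkpoint['callbacks'][type(self)]
            return {"batches_seen": self.batches_seen}

        def on_load_checkpoint(self, checkpointed_state):
            # receives the dict returned above when the checkpoint is restored
            self.batches_seen = checkpointed_state["batches_seen"]

Because the saved state is keyed by the callback's class, only one ``Counter`` instance per Trainer can be restored reliably (constraint 2 above).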


Best Practices
--------------
The following are best practices when using/designing callbacks.

@@ -155,3 +155,12 @@ class Callback(abc.ABC):

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Called when the training is interrupted by KeyboardInterrupt."""
        pass

    def on_save_checkpoint(self, trainer, pl_module):
        """Called when saving a model checkpoint, use to persist state."""
        pass

    def on_load_checkpoint(self, checkpointed_state):
        """Called when loading a model checkpoint, use to reload state."""
        pass
@@ -19,8 +19,6 @@ Early Stopping

Monitor a validation metric and stop training when it stops improving.

"""
from copy import deepcopy

import numpy as np
import torch
import torch.distributed as dist
@@ -126,7 +124,7 @@ class EarlyStopping(Callback):
    def monitor_op(self):
        return self.mode_dict[self.mode]

    def state_dict(self):
    def on_save_checkpoint(self, trainer, pl_module):
        return {
            'wait_count': self.wait_count,
            'stopped_epoch': self.stopped_epoch,
@@ -134,12 +132,11 @@ class EarlyStopping(Callback):
            'patience': self.patience
        }

    def load_state_dict(self, state_dict):
        state_dict = deepcopy(state_dict)
        self.wait_count = state_dict['wait_count']
        self.stopped_epoch = state_dict['stopped_epoch']
        self.best_score = state_dict['best_score']
        self.patience = state_dict['patience']
    def on_load_checkpoint(self, checkpointed_state):
        self.wait_count = checkpointed_state['wait_count']
        self.stopped_epoch = checkpointed_state['stopped_epoch']
        self.best_score = checkpointed_state['best_score']
        self.patience = checkpointed_state['patience']

    def on_validation_end(self, trainer, pl_module):
        if trainer.running_sanity_check:
@@ -426,3 +426,13 @@ class ModelCheckpoint(Callback):
        for cur_path in del_list:
            if cur_path != filepath:
                self._del_model(cur_path)

    def on_save_checkpoint(self, trainer, pl_module):
        return {
            'best_model_score': self.best_model_score,
            'best_model_path': self.best_model_path,
        }

    def on_load_checkpoint(self, checkpointed_state):
        self.best_model_score = checkpointed_state['best_model_score']
        self.best_model_path = checkpointed_state['best_model_path']
@@ -13,6 +13,7 @@
# limitations under the License.

from abc import ABC
from copy import deepcopy
from typing import Callable, List

from pytorch_lightning.callbacks import Callback
@@ -189,3 +190,22 @@ class TrainerCallbackHookMixin(ABC):
        """Called when the training is interrupted by KeyboardInterrupt."""
        for callback in self.callbacks:
            callback.on_keyboard_interrupt(self, self.get_model())

    def on_save_checkpoint(self):
        """Called when saving a model checkpoint."""
        callback_states = {}
        for callback in self.callbacks:
            callback_class = type(callback)
            state = callback.on_save_checkpoint(self, self.get_model())
            if state:
                callback_states[callback_class] = state
        return callback_states

    def on_load_checkpoint(self, checkpoint):
        """Called when loading a model checkpoint."""
        callback_states = checkpoint.get('callbacks')
        for callback in self.callbacks:
            state = callback_states.get(type(callback))
            if state:
                state = deepcopy(state)
                callback.on_load_checkpoint(state)
@@ -109,16 +109,15 @@ import torch.distributed as torch_distrib

import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import (
    LightningDistributedDataParallel,
    LightningDataParallel,
)
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.utilities.cloud_io import gfile, makedirs
from pytorch_lightning.utilities.cloud_io import load as pl_load, atomic_save
from pytorch_lightning.overrides.data_parallel import LightningDataParallel, LightningDistributedDataParallel
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, gfile
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import makedirs
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

try:
    import torch_xla
@@ -340,20 +339,9 @@ class TrainerIOMixin(ABC):

        if not weights_only:

            # TODO support more generic way for callbacks to persist a state_dict in a checkpoint
            checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
            early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]

            if checkpoint_callbacks:
                # we add the official checkpoint callback to the end of the list
                # extra user provided callbacks will not be persisted yet
                checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE] = self.checkpoint_callback.best_model_score
                checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH] = self.checkpoint_callback.best_model_path

            if early_stopping_callbacks and checkpoint_callbacks:
                # we add the official early stopping callback to the end of the list
                # extra user provided callbacks will not be persisted yet
                checkpoint['early_stop_callback_state_dict'] = early_stopping_callbacks[-1].state_dict()
            # save callbacks
            callback_states = self.on_save_checkpoint()
            checkpoint['callbacks'] = callback_states

            # save optimizers
            optimizer_states = []
@@ -424,25 +412,16 @@ class TrainerIOMixin(ABC):
                ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
            )

        # TODO support more generic way for callbacks to load callback state_dicts
        checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
        early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)]
        if any([key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS]):
            raise ValueError(
                "The checkpoint you're attempting to load follows an"
                " outdated schema. You can upgrade to the current schema by running"
                " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
                " where `model.ckpt` is your checkpoint file."
            )

        if checkpoint_callbacks:
            if ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE in checkpoint:
                checkpoint_callbacks[-1].best_model_score = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE]
            else:
                # Old naming until version 0.7.6
                rank_zero_warn(
                    'Loading a checkpoint created with an old version of Lightning; '
                    'this will not be supported in the future.'
                )
                checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best']
            checkpoint_callbacks[-1].best_model_path = checkpoint[ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH]

        if early_stopping_callbacks:
            state = checkpoint['early_stop_callback_state_dict']
            early_stopping_callbacks[-1].load_state_dict(state)
        # load callback states
        self.on_load_checkpoint(checkpoint)

        self.global_step = checkpoint['global_step']
        self.current_epoch = checkpoint['epoch']
@@ -0,0 +1,45 @@
import argparse
from shutil import copyfile

import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

KEYS_MAPPING = {
    "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"),
    "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"),
    "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"),
    "early_stop_callback_wait": (EarlyStopping, "wait_count"),
    "early_stop_callback_patience": (EarlyStopping, "patience"),
}


def upgrade_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    checkpoint["callbacks"] = checkpoint.get("callbacks") or {}

    for key, new_path in KEYS_MAPPING.items():
        if key in checkpoint:
            value = checkpoint[key]
            callback_type, callback_key = new_path
            checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {}
            checkpoint["callbacks"][callback_type][callback_key] = value
            del checkpoint[key]

    torch.save(checkpoint, filepath)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Upgrade an old checkpoint to the current schema. \
        This will also save a backup of the original file."
    )
    parser.add_argument("--file", help="filepath for a checkpoint to upgrade")

    args = parser.parse_args()

    log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.")
    copyfile(args.file, args.file + ".bak")
    upgrade_checkpoint(args.file)
@@ -11,6 +11,23 @@ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from tests.base import EvalModelTemplate


class EarlyStoppingTestRestore(EarlyStopping):
    # this class has to be defined outside the test function, otherwise we get pickle error
    def __init__(self, expected_state=None):
        super().__init__()
        self.expected_state = expected_state
        # cache the state for each epoch
        self.saved_states = []

    def on_train_start(self, trainer, pl_module):
        if self.expected_state:
            assert self.on_save_checkpoint(trainer, pl_module) == self.expected_state

    def on_validation_end(self, trainer, pl_module):
        super().on_validation_end(trainer, pl_module)
        self.saved_states.append(self.on_save_checkpoint(trainer, pl_module).copy())


def test_resume_early_stopping_from_checkpoint(tmpdir):
    """
    Prevent regressions to bugs:
@@ -18,27 +35,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
    https://github.com/PyTorchLightning/pytorch-lightning/issues/1463
    """

    class EarlyStoppingTestStore(EarlyStopping):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # cache the state for each epoch
            self.saved_states = []

        def on_validation_end(self, trainer, pl_module):
            super().on_validation_end(trainer, pl_module)
            self.saved_states.append(deepcopy(self.state_dict()))

    class EarlyStoppingTestRestore(EarlyStopping):
        def __init__(self, expected_state):
            super().__init__()
            self.expected_state = expected_state

        def on_train_start(self, trainer, pl_module):
            assert self.state_dict() == self.expected_state

    model = EvalModelTemplate()
    checkpoint_callback = ModelCheckpoint(save_top_k=1)
    early_stop_callback = EarlyStoppingTestStore()
    early_stop_callback = EarlyStoppingTestRestore()
    trainer = Trainer(
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint_callback,
@@ -52,9 +51,9 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
    # ensure state is persisted properly
    checkpoint = torch.load(checkpoint_filepath)
    # the checkpoint saves "epoch + 1"
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint['epoch'] - 1]
    early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
    assert 4 == len(early_stop_callback.saved_states)
    assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state
    assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state

    # ensure state is reloaded properly (assertion in the callback)
    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state)
@@ -84,11 +83,10 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
    assert len(trainer.dev_debugger.early_stopping_history) == expected_count


@pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [
    ([6, 5, 5, 5, 5, 5], 3, 4),
    ([6, 5, 4, 4, 3, 3], 1, 3),
    ([6, 5, 6, 5, 5, 5], 3, 4),
])
@pytest.mark.parametrize(
    "loss_values, patience, expected_stop_epoch",
    [([6, 5, 5, 5, 5, 5], 3, 4), ([6, 5, 4, 4, 3, 3], 1, 3), ([6, 5, 6, 5, 5, 5], 3, 4),],
)
def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch):
    """Test to ensure that early stopping is not triggered before patience is exhausted."""
@@ -15,7 +15,7 @@ from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate


@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
    """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """
    tutils.reset_seed()
@@ -25,11 +25,11 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):

    trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=checkpoint, overfit_batches=0.20, max_epochs=2)
    trainer.fit(model)
    assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints'
    assert checkpoint.dirpath == tmpdir / trainer.logger.name / "version_0" / "checkpoints"


@pytest.mark.parametrize(
    'logger_version,expected', [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')],
    "logger_version,expected", [(None, "version_0"), (1, "version_1"), ("awesome", "awesome")],
)
def test_model_checkpoint_path(tmpdir, logger_version, expected):
    """Test that "version_" prefix is only added when logger's version is an integer"""
@@ -86,7 +86,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir):
    num_epochs = 4
    model_checkpoint = ModelCheckpointTestInvocations(expected_count=num_epochs, save_top_k=-1)
    trainer = Trainer(
        distributed_backend='ddp_cpu',
        distributed_backend="ddp_cpu",
        num_processes=2,
        default_root_dir=tmpdir,
        early_stop_callback=False,
@@ -104,10 +104,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    num_epochs = 3
    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=num_epochs, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        early_stop_callback=False,
        checkpoint_callback=model_checkpoint,
        max_epochs=num_epochs,
        default_root_dir=tmpdir, early_stop_callback=False, checkpoint_callback=model_checkpoint, max_epochs=num_epochs,
    )
    trainer.fit(model)
    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})  # epoch=3.ckpt
@@ -115,15 +112,24 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
    assert path_last_epoch != path_last
    ckpt_last_epoch = torch.load(path_last_epoch)
    ckpt_last = torch.load(path_last)
    matching_keys = (

    trainer_keys = (
        "epoch",
        "global_step",
        ModelCheckpoint.CHECKPOINT_STATE_BEST_SCORE,
        ModelCheckpoint.CHECKPOINT_STATE_BEST_PATH,
    )
    for key in matching_keys:
    for key in trainer_keys:
        assert ckpt_last_epoch[key] == ckpt_last[key]

    checkpoint_callback_keys = (
        "best_model_score",
        "best_model_path",
    )
    for key in checkpoint_callback_keys:
        assert (
ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
|
||||
== ckpt_last_epoch["callbacks"][type(model_checkpoint)][key]
|
||||
)
|
||||
|
||||
# it is easier to load the model objects than to iterate over the raw dict of tensors
|
||||
model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
|
||||
model_last = EvalModelTemplate.load_from_checkpoint(path_last)
|
||||
|
|
|
@@ -0,0 +1,36 @@
import pytest
import os

import torch

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint


@pytest.mark.parametrize(
    "old_checkpoint, new_checkpoint",
    [
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}},
        ),
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}},
        ),
        (
            {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": 'path'},
            {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": 'path'}}},
        ),
        (
            {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4},
            {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}},
        ),
    ],
)
def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint):
    filepath = os.path.join(tmpdir, "model.ckpt")
    torch.save(old_checkpoint, filepath)
    upgrade_checkpoint(filepath)
    updated_checkpoint = torch.load(filepath)
    assert updated_checkpoint == new_checkpoint