Allow callbacks to be restored not just during training (#20403)

* Allow callbacks to be restored not just during training

* add test case

* test test case failure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test case

---------

Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
Alan Chu 2024-11-14 14:46:19 -08:00 committed by GitHub
parent cd2bd3ce53
commit c110f4f3f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 4 deletions

View File

@ -397,9 +397,7 @@ class _CheckpointConnector:
self.resume_start(checkpoint_path) self.resume_start(checkpoint_path)
self.restore_model() self.restore_model()
self.restore_datamodule() self.restore_datamodule()
if self.trainer.state.fn == TrainerFn.FITTING: self.restore_callbacks()
# restore callback states
self.restore_callbacks()
def dump_checkpoint(self, weights_only: bool = False) -> dict: def dump_checkpoint(self, weights_only: bool = False) -> dict:
"""Creating a model checkpoint dictionary object from various component states. """Creating a model checkpoint dictionary object from various component states.

View File

@ -18,7 +18,7 @@ from unittest.mock import ANY, Mock
import pytest import pytest
import torch import torch
from lightning.pytorch import Trainer from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.migration.utils import _set_version from lightning.pytorch.utilities.migration.utils import _set_version
@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2) trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt")) trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
model.load_state_dict.assert_called_once_with(ANY, strict=expected) model.load_state_dict.assert_called_once_with(ANY, strict=expected)
@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
    """Test that callbacks are properly restored in non-fit phases.

    A callback carrying a boolean flag is saved to a checkpoint with the flag set to ``True``. A fresh callback
    (flag ``False``) is then attached to a new ``Trainer`` and ``validate``/``test``/``predict`` is called with
    ``ckpt_path``. The flag must be ``True`` afterwards.

    The checkpoint connector is deliberately NOT invoked by hand here: restoring the callbacks manually before
    calling the entry point would make the final assertion pass regardless of what the entry point does.

    """

    class TestCallback(Callback):
        def __init__(self):
            self.restored = False

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # Only read the flag when the checkpoint actually carries callback state.
            if "callbacks" in checkpoint:
                callback_state = checkpoint["callbacks"][self.__class__.__name__]
                self.restored = callback_state["restored"]

        def state_dict(self):
            return {"restored": self.restored}

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Store this callback's state under its class name, creating the "callbacks" entry if it is missing.
            checkpoint["callbacks"] = checkpoint.get("callbacks", {})
            checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()

    # First create and train a model with the callback
    callback = TestCallback()
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
    trainer.fit(model)

    # Set the callback state to True before saving
    callback.restored = True
    ckpt_path = tmp_path / "checkpoint.ckpt"
    trainer.save_checkpoint(ckpt_path)

    # Now create new instances and test restoration
    new_callback = TestCallback()
    new_model = BoringModel()
    assert not new_callback.restored  # Should start False

    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])

    # Run the evaluation phase (validate/test/predict); this call alone must trigger the callback restoration
    fn = getattr(new_trainer, trainer_fn)
    fn(new_model, ckpt_path=ckpt_path)

    assert new_callback.restored  # Should be True after loading the checkpoint