From c110f4f3f60c643740f5e3573546abfcb5355315 Mon Sep 17 00:00:00 2001 From: Alan Chu <30797645+chualanagit@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:46:19 -0800 Subject: [PATCH] Allow callbacks to be restored not just during training (#20403) * Allow callbacks to be restored not just during training * add test case * test test case failure * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test case --------- Co-authored-by: Alan Chu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Antiga --- .../connectors/checkpoint_connector.py | 4 +- .../connectors/test_checkpoint_connector.py | 52 ++++++++++++++++++- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index c73ceb32ec..a41f87d418 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -397,9 +397,7 @@ class _CheckpointConnector: self.resume_start(checkpoint_path) self.restore_model() self.restore_datamodule() - if self.trainer.state.fn == TrainerFn.FITTING: - # restore callback states - self.restore_callbacks() + self.restore_callbacks() def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 39911f9edd..d29e2285e9 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -18,7 +18,7 @@ from unittest.mock import ANY, Mock import pytest import torch from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.migration.utils import _set_version @@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path): trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2) trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt")) model.load_state_dict.assert_called_once_with(ANY, strict=expected) + + +@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"]) +def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn): + """Test that callbacks are properly restored in non-fit phases.""" + + class TestCallback(Callback): + def __init__(self): + self.restored = False + + def on_load_checkpoint(self, trainer, pl_module, checkpoint): + if "callbacks" in checkpoint: + callback_state = checkpoint["callbacks"][self.__class__.__name__] + self.restored = callback_state["restored"] + + def state_dict(self): + return {"restored": self.restored} + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + checkpoint["callbacks"] = checkpoint.get("callbacks", {}) + checkpoint["callbacks"][self.__class__.__name__] = self.state_dict() + + # First create and train a model with the callback + callback = TestCallback() + model = BoringModel() + trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1) + trainer.fit(model) + + # Set the callback state to True before saving + callback.restored = True + ckpt_path = tmp_path / "checkpoint.ckpt" + trainer.save_checkpoint(ckpt_path) + + # Now create new instances and test restoration + new_callback = TestCallback() + new_model = BoringModel() + assert not new_callback.restored # Should start False + + new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback]) + + # Connect the model and restore callbacks before evaluation + new_trainer.strategy.connect(new_model) + new_trainer._checkpoint_connector.resume_start(ckpt_path) + new_trainer._checkpoint_connector.restore_callbacks() + + # Run the evaluation phase (validate/test/predict) + fn = getattr(new_trainer, trainer_fn) + fn(new_model, ckpt_path=ckpt_path) + + assert new_callback.restored # Should be True after loading the checkpoint