diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py index c73ceb32ec..a41f87d418 100644 --- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py +++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py @@ -397,9 +397,7 @@ class _CheckpointConnector: self.resume_start(checkpoint_path) self.restore_model() self.restore_datamodule() - if self.trainer.state.fn == TrainerFn.FITTING: - # restore callback states - self.restore_callbacks() + self.restore_callbacks() def dump_checkpoint(self, weights_only: bool = False) -> dict: """Creating a model checkpoint dictionary object from various component states. diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 39911f9edd..d29e2285e9 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -18,7 +18,7 @@ from unittest.mock import ANY, Mock import pytest import torch from lightning.pytorch import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.migration.utils import _set_version @@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path): trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2) trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt")) model.load_state_dict.assert_called_once_with(ANY, strict=expected) + + +@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"]) +def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn): + """Test that callbacks are properly restored in non-fit phases.""" + + 
+    class TestCallback(Callback):
+        """Callback that round-trips a ``restored`` flag through the checkpoint dict."""
+
+        def __init__(self):
+            self.restored = False
+
+        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
+            if "callbacks" in checkpoint:
+                # Read back the entry that ``on_save_checkpoint`` stores under the class name
+                callback_state = checkpoint["callbacks"][self.__class__.__name__]
+                self.restored = callback_state["restored"]
+
+        def state_dict(self):
+            return {"restored": self.restored}
+
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            checkpoint["callbacks"] = checkpoint.get("callbacks", {})
+            checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()
+
+    # First create and train a model with the callback
+    callback = TestCallback()
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
+    trainer.fit(model)
+
+    # Set the callback state to True before saving
+    callback.restored = True
+    ckpt_path = tmp_path / "checkpoint.ckpt"
+    trainer.save_checkpoint(ckpt_path)
+
+    # Now create new instances and test restoration
+    new_callback = TestCallback()
+    new_model = BoringModel()
+    assert not new_callback.restored  # Should start False
+
+    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])
+
+    # Run the evaluation phase (validate/test/predict). Callbacks are deliberately NOT
+    # restored by hand first: the entry point itself must restore them from ``ckpt_path``,
+    # otherwise this test would pass regardless of the connector change.
+    fn = getattr(new_trainer, trainer_fn)
+    fn(new_model, ckpt_path=ckpt_path)
+
+    assert new_callback.restored  # Should be True after loading the checkpoint