Allow callbacks to be restored not just during training (#20403)

* Allow callbacks to be restored not just during training

* add test case

* test test case failure

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test case

---------

Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Commit c110f4f3f6 (parent cd2bd3ce53), authored by Alan Chu on 2024-11-14 14:46:19 -08:00 and committed by GitHub.
2 changed files with 52 additions and 4 deletions

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

@@ -397,8 +397,6 @@ class _CheckpointConnector:
         self.resume_start(checkpoint_path)
         self.restore_model()
         self.restore_datamodule()
-        if self.trainer.state.fn == TrainerFn.FITTING:
-            # restore callback states
-            self.restore_callbacks()
+        self.restore_callbacks()
 
     def dump_checkpoint(self, weights_only: bool = False) -> dict:
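In effect, restore_callbacks() now runs for every trainer phase that loads a checkpoint, not only when trainer.state.fn == TrainerFn.FITTING. A minimal sketch of the user-visible behavior this enables, not part of this commit (the CounterCallback class and the counter.ckpt path are illustrative assumptions):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.demos.boring_classes import BoringModel


class CounterCallback(Callback):
    """Hypothetical callback; Lightning persists its state via state_dict()."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1

    def state_dict(self):
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        self.batches_seen = state_dict["batches_seen"]


model = BoringModel()
trainer = Trainer(max_steps=2, callbacks=[CounterCallback()])
trainer.fit(model)
trainer.save_checkpoint("counter.ckpt")  # callback state is stored under checkpoint["callbacks"]

# Loading the checkpoint in a non-fit phase now restores the callback state.
# Before this change, restore_callbacks() was skipped outside of fitting,
# so fresh.batches_seen would have stayed 0 after validate().
fresh = CounterCallback()
eval_trainer = Trainer(callbacks=[fresh])
eval_trainer.validate(model, ckpt_path="counter.ckpt")
assert fresh.batches_seen == 2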

tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py

@@ -18,7 +18,7 @@ from unittest.mock import ANY, Mock
 import pytest
 import torch
 from lightning.pytorch import Trainer
-from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.migration.utils import _set_version
@@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
     trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
     trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
     model.load_state_dict.assert_called_once_with(ANY, strict=expected)
+
+
+@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
+def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
+    """Test that callbacks are properly restored in non-fit phases."""
+
+    class TestCallback(Callback):
+        def __init__(self):
+            self.restored = False
+
+        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
+            if "callbacks" in checkpoint:
+                callback_state = checkpoint["callbacks"][self.__class__.__name__]
+                self.restored = callback_state["restored"]
+
+        def state_dict(self):
+            return {"restored": self.restored}
+
+        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+            checkpoint["callbacks"] = checkpoint.get("callbacks", {})
+            checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()
+
+    # First create and train a model with the callback
+    callback = TestCallback()
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
+    trainer.fit(model)
+
+    # Set the callback state to True before saving
+    callback.restored = True
+    ckpt_path = tmp_path / "checkpoint.ckpt"
+    trainer.save_checkpoint(ckpt_path)
+
+    # Now create new instances and test restoration
+    new_callback = TestCallback()
+    new_model = BoringModel()
+    assert not new_callback.restored  # Should start False
+
+    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])
+
+    # Connect the model and restore callbacks before evaluation
+    new_trainer.strategy.connect(new_model)
+    new_trainer._checkpoint_connector.resume_start(ckpt_path)
+    new_trainer._checkpoint_connector.restore_callbacks()
+
+    # Run the evaluation phase (validate/test/predict)
+    fn = getattr(new_trainer, trainer_fn)
+    fn(new_model, ckpt_path=ckpt_path)
+    assert new_callback.restored  # Should be True after loading the checkpoint
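To run just the new test locally, assuming it lands in the checkpoint-connector test module shown above, something like:

python -m pytest tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py -k test_restore_callbacks_in_non_fit_phases

Note that the test exercises restoration twice: once by driving resume_start() and restore_callbacks() on the private _checkpoint_connector directly, and once by passing ckpt_path to the public validate/test/predict entry points, which reach the same code path through _restore_modules_and_callbacks.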