Allow callbacks to be restored not just during training (#20403)
* Allow callbacks to be restored not just during training
* add test case
* test test case failure
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix test case
---------
Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
parent
cd2bd3ce53
commit
c110f4f3f6
|
@ -397,8 +397,6 @@ class _CheckpointConnector:
|
|||
self.resume_start(checkpoint_path)
|
||||
self.restore_model()
|
||||
self.restore_datamodule()
|
||||
if self.trainer.state.fn == TrainerFn.FITTING:
|
||||
# restore callback states
|
||||
self.restore_callbacks()
|
||||
|
||||
def dump_checkpoint(self, weights_only: bool = False) -> dict:
|
||||
|
|
|
@ -18,7 +18,7 @@ from unittest.mock import ANY, Mock
|
|||
import pytest
|
||||
import torch
|
||||
from lightning.pytorch import Trainer
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.trainer.states import TrainerFn
|
||||
from lightning.pytorch.utilities.migration.utils import _set_version
|
||||
|
@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
|
|||
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
|
||||
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
|
||||
model.load_state_dict.assert_called_once_with(ANY, strict=expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
    """Test that callbacks are properly restored in non-fit phases.

    A callback whose ``restored`` flag is set to ``True`` is saved into a checkpoint
    after a short fit. A fresh trainer with a fresh callback (flag ``False``) then runs
    ``validate``/``test``/``predict`` with ``ckpt_path`` pointing at that checkpoint.

    The checkpoint is deliberately not loaded by hand before the phase runs, so the
    final assertion can only pass if the trainer entry point itself restores the
    callback state from the checkpoint.

    """

    class TestCallback(Callback):
        def __init__(self):
            self.restored = False

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            # Pull the flag back out of the entry written by ``on_save_checkpoint``.
            if "callbacks" in checkpoint:
                callback_state = checkpoint["callbacks"][self.__class__.__name__]
                self.restored = callback_state["restored"]

        def state_dict(self):
            return {"restored": self.restored}

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Store the state under the class name so ``on_load_checkpoint`` can find it.
            checkpoint["callbacks"] = checkpoint.get("callbacks", {})
            checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()

    # First create and train a model with the callback
    callback = TestCallback()
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
    trainer.fit(model)

    # Set the callback state to True before saving
    callback.restored = True
    ckpt_path = tmp_path / "checkpoint.ckpt"
    trainer.save_checkpoint(ckpt_path)

    # Now create new instances and test restoration
    new_callback = TestCallback()
    new_model = BoringModel()
    assert not new_callback.restored  # Should start False

    new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])

    # Run the evaluation phase (validate/test/predict). No manual restore happens
    # beforehand: the phase itself is the only thing allowed to load the checkpoint.
    fn = getattr(new_trainer, trainer_fn)
    fn(new_model, ckpt_path=ckpt_path)

    assert new_callback.restored  # Should be True after loading the checkpoint
|
||||
|
|
Loading…
Reference in New Issue