Allow callbacks to be restored not just during training (#20403)
* Allow callbacks to be restored not just during training
* add test case
* test test case failure
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix test case

---------

Co-authored-by: Alan Chu <alanchu@Alans-Air.lan>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
This commit is contained in:
parent
cd2bd3ce53
commit
c110f4f3f6
|
@@ -397,9 +397,7 @@ class _CheckpointConnector:
|
||||||
self.resume_start(checkpoint_path)
|
self.resume_start(checkpoint_path)
|
||||||
self.restore_model()
|
self.restore_model()
|
||||||
self.restore_datamodule()
|
self.restore_datamodule()
|
||||||
if self.trainer.state.fn == TrainerFn.FITTING:
|
self.restore_callbacks()
|
||||||
# restore callback states
|
|
||||||
self.restore_callbacks()
|
|
||||||
|
|
||||||
def dump_checkpoint(self, weights_only: bool = False) -> dict:
|
def dump_checkpoint(self, weights_only: bool = False) -> dict:
|
||||||
"""Creating a model checkpoint dictionary object from various component states.
|
"""Creating a model checkpoint dictionary object from various component states.
|
||||||
|
|
|
@@ -18,7 +18,7 @@ from unittest.mock import ANY, Mock
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from lightning.pytorch import Trainer
|
from lightning.pytorch import Trainer
|
||||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||||
from lightning.pytorch.trainer.states import TrainerFn
|
from lightning.pytorch.trainer.states import TrainerFn
|
||||||
from lightning.pytorch.utilities.migration.utils import _set_version
|
from lightning.pytorch.utilities.migration.utils import _set_version
|
||||||
|
@@ -234,3 +234,53 @@ def test_strict_loading(strict_loading, expected, tmp_path):
|
||||||
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
|
trainer = Trainer(default_root_dir=tmp_path, barebones=True, max_steps=2)
|
||||||
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
|
trainer.fit(model, ckpt_path=(tmp_path / "checkpoint.ckpt"))
|
||||||
model.load_state_dict.assert_called_once_with(ANY, strict=expected)
|
model.load_state_dict.assert_called_once_with(ANY, strict=expected)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("trainer_fn", ["validate", "test", "predict"])
|
||||||
|
def test_restore_callbacks_in_non_fit_phases(tmp_path, trainer_fn):
|
||||||
|
"""Test that callbacks are properly restored in non-fit phases."""
|
||||||
|
|
||||||
|
class TestCallback(Callback):
|
||||||
|
def __init__(self):
|
||||||
|
self.restored = False
|
||||||
|
|
||||||
|
def on_load_checkpoint(self, trainer, pl_module, checkpoint):
|
||||||
|
if "callbacks" in checkpoint:
|
||||||
|
callback_state = checkpoint["callbacks"][self.__class__.__name__]
|
||||||
|
self.restored = callback_state["restored"]
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return {"restored": self.restored}
|
||||||
|
|
||||||
|
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
|
||||||
|
checkpoint["callbacks"] = checkpoint.get("callbacks", {})
|
||||||
|
checkpoint["callbacks"][self.__class__.__name__] = self.state_dict()
|
||||||
|
|
||||||
|
# First create and train a model with the callback
|
||||||
|
callback = TestCallback()
|
||||||
|
model = BoringModel()
|
||||||
|
trainer = Trainer(default_root_dir=tmp_path, callbacks=[callback], max_steps=1)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# Set the callback state to True before saving
|
||||||
|
callback.restored = True
|
||||||
|
ckpt_path = tmp_path / "checkpoint.ckpt"
|
||||||
|
trainer.save_checkpoint(ckpt_path)
|
||||||
|
|
||||||
|
# Now create new instances and test restoration
|
||||||
|
new_callback = TestCallback()
|
||||||
|
new_model = BoringModel()
|
||||||
|
assert not new_callback.restored # Should start False
|
||||||
|
|
||||||
|
new_trainer = Trainer(default_root_dir=tmp_path, callbacks=[new_callback])
|
||||||
|
|
||||||
|
# Connect the model and restore callbacks before evaluation
|
||||||
|
new_trainer.strategy.connect(new_model)
|
||||||
|
new_trainer._checkpoint_connector.resume_start(ckpt_path)
|
||||||
|
new_trainer._checkpoint_connector.restore_callbacks()
|
||||||
|
|
||||||
|
# Run the evaluation phase (validate/test/predict)
|
||||||
|
fn = getattr(new_trainer, trainer_fn)
|
||||||
|
fn(new_model, ckpt_path=ckpt_path)
|
||||||
|
|
||||||
|
assert new_callback.restored # Should be True after loading the checkpoint
|
||||||
|
|
Loading…
Reference in New Issue