diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 724ccca7cb..9b91171d37 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -126,6 +126,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - HPC checkpoints are now loaded automatically only in slurm environment when no specific value for `ckpt_path` has been set ([#14911](https://github.com/Lightning-AI/lightning/pull/14911)) +- The `Callback.on_load_checkpoint` now gets the full checkpoint dictionary and the `callback_state` argument was renamed `checkpoint` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835)) + + ### Deprecated - Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000)) @@ -302,6 +305,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the deprecated `LightningDataModule.on_save/load_checkpoint` hooks ([#14909](https://github.com/Lightning-AI/lightning/pull/14909)) +- Removed support for returning a value in `Callback.on_save_checkpoint` in favor of implementing `Callback.state_dict` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835)) + + + ### Fixed - Fixed an issue with `LightningLite.setup()` not setting the `.device` attribute correctly on the returned wrapper ([#14822](https://github.com/Lightning-AI/lightning/pull/14822)) diff --git a/src/pytorch_lightning/callbacks/callback.py b/src/pytorch_lightning/callbacks/callback.py index 484be2213e..f627b43cf5 100644 --- a/src/pytorch_lightning/callbacks/callback.py +++ b/src/pytorch_lightning/callbacks/callback.py @@ -277,7 +277,7 @@ class Callback: def on_save_checkpoint( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] - ) -> Optional[dict]: + ) -> None: r""" Called when saving a checkpoint to give you a chance to store anything else you might want to save. @@ -285,18 +285,10 @@ class Callback: trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance. checkpoint: the checkpoint dictionary that will be saved. - - Returns: - None or the callback state. Support for returning callback state will be removed in v1.8. - - .. deprecated:: v1.6 - Returning a value from this method was deprecated in v1.6 and will be removed in v1.8. - Implement ``Callback.state_dict`` instead to return state. - In v1.8 ``Callback.on_save_checkpoint`` can only return None. """ def on_load_checkpoint( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] ) -> None: r""" Called when loading a model checkpoint, use to reload state. @@ -304,18 +296,7 @@ class Callback: Args: trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance. - callback_state: the callback state returned by ``on_save_checkpoint``. - - Note: - The ``on_load_checkpoint`` won't be called with an undefined state. - If your ``on_load_checkpoint`` hook behavior doesn't rely on a state, - you will still need to override ``on_save_checkpoint`` to return a ``dummy state``. - - .. deprecated:: v1.6 - This callback hook will change its signature and behavior in v1.8. - If you wish to load the state of the callback, use ``Callback.load_state_dict`` instead. - In v1.8 ``Callback.on_load_checkpoint(checkpoint)`` will receive the entire loaded - checkpoint dictionary instead of only the callback state from the checkpoint. + checkpoint: the full checkpoint dictionary that got loaded by the Trainer. """ def on_before_backward(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: Tensor) -> None: diff --git a/src/pytorch_lightning/callbacks/pruning.py b/src/pytorch_lightning/callbacks/pruning.py index 7698b23741..ad5f8776c5 100644 --- a/src/pytorch_lightning/callbacks/pruning.py +++ b/src/pytorch_lightning/callbacks/pruning.py @@ -423,9 +423,7 @@ class ModelPruning(Callback): return apply_to_collection(state_dict, Tensor, move_to_cpu) - def on_save_checkpoint( - self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any] - ) -> Optional[dict]: + def on_save_checkpoint(self, trainer: "pl.Trainer", pl_module: LightningModule, checkpoint: Dict[str, Any]) -> None: if self._make_pruning_permanent: rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint") # manually prune the weights so training can keep going with the same buffers diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index 9287abc949..f363abf92f 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect + import pytorch_lightning as pl from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning.accelerators.ipu import IPUAccelerator @@ -220,12 +222,16 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: "The `on_before_accelerator_backend_setup` callback hook was deprecated in" " v1.6 and will be removed in v1.8. Use `setup()` instead." ) - if is_overridden(method_name="on_load_checkpoint", instance=callback): - rank_zero_deprecation( - f"`{callback.__class__.__name__}.on_load_checkpoint` will change its signature and behavior in v1.8." + + has_legacy_argument = "callback_state" in inspect.signature(callback.on_load_checkpoint).parameters + if is_overridden(method_name="on_load_checkpoint", instance=callback) and has_legacy_argument: + # TODO: Remove this error message in v2.0 + raise RuntimeError( + f"`{callback.__class__.__name__}.on_load_checkpoint` has changed its signature and behavior in v1.8." " If you wish to load the state of the callback, use `load_state_dict` instead." - " In v1.8 `on_load_checkpoint(..., checkpoint)` will receive the entire loaded" - " checkpoint dictionary instead of callback state." + " As of 1.8, `on_load_checkpoint(..., checkpoint)` receives the entire loaded" + " checkpoint dictionary instead of the callback state. To continue using this hook and avoid this error" + " message, rename the `callback_state` argument to `checkpoint`." ) for hook, alternative_hook in ( diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 7d8844547b..130613e6fd 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1354,11 +1354,7 @@ class Trainer: return callback_state_dicts def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook. - - Will be removed in v1.8: If state is returned, we insert the callback state into - ``checkpoint["callbacks"][Callback.state_key]``. It overrides ``state_dict`` if already present. - """ + """Called when saving a model checkpoint, calls every callback's `on_save_checkpoint` hook.""" pl_module = self.lightning_module if pl_module: prev_fx_name = pl_module._current_fx_name @@ -1367,13 +1363,13 @@ class Trainer: for callback in self.callbacks: with self.profiler.profile(f"[Callback]{callback.state_key}.on_save_checkpoint"): state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) - if state: - rank_zero_deprecation( - f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` is deprecated in v1.6" - " and will be removed in v1.8. Please override `Callback.state_dict`" - " to return state to be saved." + if state is not None: + # TODO: Remove this error message in v2.0 + raise ValueError( + f"Returning a value from `{callback.__class__.__name__}.on_save_checkpoint` was deprecated in v1.6" + f" and is no longer supported as of v1.8. Please override `Callback.state_dict` to return state" + f" to be saved." ) - checkpoint["callbacks"][callback.state_key] = state if pl_module: # restore current_fx when nested context @@ -1406,11 +1402,8 @@ class Trainer: ) for callback in self.callbacks: - state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key)) - if state: - state = deepcopy(state) - with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"): - callback.on_load_checkpoint(self, self.lightning_module, state) + with self.profiler.profile(f"[Callback]{callback.state_key}.on_load_checkpoint"): + callback.on_load_checkpoint(self, self.lightning_module, checkpoint) if pl_module: # restore current_fx when nested context diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index d8664c6a1b..c8dca41305 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -133,44 +133,6 @@ def test_resume_callback_state_saved_by_type_stateful(tmpdir): assert callback.state == 111 -class OldStatefulCallbackHooks(Callback): - def __init__(self, state): - self.state = state - - @property - def state_key(self): - return type(self) - - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - return {"state": self.state} - - def on_load_checkpoint(self, trainer, pl_module, callback_state): - self.state = callback_state["state"] - - -def test_resume_callback_state_saved_by_type_hooks(tmpdir): - """Test that a legacy checkpoint that didn't use a state key before can still be loaded, using deprecated - on_save/load_checkpoint signatures.""" - # TODO: remove old on_save/load_checkpoint signature support in v1.8 - # in favor of Stateful and new on_save/load_checkpoint signatures - # on_save_checkpoint() -> dict, on_load_checkpoint(callback_state) - # will become - # on_save_checkpoint() -> None and on_load_checkpoint(checkpoint) - model = BoringModel() - callback = OldStatefulCallbackHooks(state=111) - trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) - with pytest.deprecated_call(): - trainer.fit(model) - ckpt_path = Path(trainer.checkpoint_callback.best_model_path) - assert ckpt_path.exists() - - callback = OldStatefulCallbackHooks(state=222) - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback]) - with pytest.deprecated_call(): - trainer.fit(model, ckpt_path=ckpt_path) - assert callback.state == 111 - - def test_resume_incomplete_callbacks_list_warning(tmpdir): model = BoringModel() callback0 = ModelCheckpoint(monitor="epoch") @@ -198,49 +160,3 @@ def test_resume_incomplete_callbacks_list_warning(tmpdir): ) with no_warning_call(UserWarning, match="Please add the following callbacks:"): trainer.fit(model, ckpt_path=ckpt_path) - - -class AllStatefulCallback(Callback): - def __init__(self, state): - self.state = state - - @property - def state_key(self): - return type(self) - - def state_dict(self): - return {"new_state": self.state} - - def load_state_dict(self, state_dict): - assert state_dict == {"old_state_precedence": 10} - self.state = state_dict["old_state_precedence"] - - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - return {"old_state_precedence": 10} - - def on_load_checkpoint(self, trainer, pl_module, callback_state): - assert callback_state == {"old_state_precedence": 10} - self.old_state_precedence = callback_state["old_state_precedence"] - - -def test_resume_callback_state_all(tmpdir): - """Test on_save/load_checkpoint state precedence over state_dict/load_state_dict until v1.8 removal.""" - # TODO: remove old on_save/load_checkpoint signature support in v1.8 - # in favor of Stateful and new on_save/load_checkpoint signatures - # on_save_checkpoint() -> dict, on_load_checkpoint(callback_state) - # will become - # on_save_checkpoint() -> None and on_load_checkpoint(checkpoint) - model = BoringModel() - callback = AllStatefulCallback(state=111) - trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) - with pytest.deprecated_call(): - trainer.fit(model) - ckpt_path = Path(trainer.checkpoint_callback.best_model_path) - assert ckpt_path.exists() - - callback = AllStatefulCallback(state=222) - trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback]) - with pytest.deprecated_call(): - trainer.fit(model, ckpt_path=ckpt_path) - assert callback.state == 10 - assert callback.old_state_precedence == 10 diff --git a/tests/tests_pytorch/callbacks/test_lambda_function.py b/tests/tests_pytorch/callbacks/test_lambda_function.py index d7816a4f39..14a7bc54ef 100644 --- a/tests/tests_pytorch/callbacks/test_lambda_function.py +++ b/tests/tests_pytorch/callbacks/test_lambda_function.py @@ -36,7 +36,6 @@ def test_lambda_call(tmpdir): hooks = get_members(Callback) - {"state_dict", "load_state_dict"} hooks_args = {h: partial(call, h) for h in hooks} - hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")] model = CustomModel() diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py index c2fdc63241..ff4b80f889 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py @@ -200,56 +200,3 @@ def test_deprecated_mc_save_checkpoint(): match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6" ): mc.save_checkpoint(trainer) - - -def test_v1_8_0_callback_on_load_checkpoint_hook(tmpdir): - class TestCallbackLoadHook(Callback): - def on_load_checkpoint(self, trainer, pl_module, callback_state): - print("overriding on_load_checkpoint") - - model = BoringModel() - trainer = Trainer( - callbacks=[TestCallbackLoadHook()], - max_epochs=1, - fast_dev_run=True, - enable_progress_bar=False, - logger=False, - default_root_dir=tmpdir, - ) - with pytest.deprecated_call( - match="`TestCallbackLoadHook.on_load_checkpoint` will change its signature and behavior in v1.8." - " If you wish to load the state of the callback, use `load_state_dict` instead." - r" In v1.8 `on_load_checkpoint\(..., checkpoint\)` will receive the entire loaded" - " checkpoint dictionary instead of callback state." - ): - trainer.fit(model) - - -def test_v1_8_0_callback_on_save_checkpoint_hook(tmpdir): - class TestCallbackSaveHookReturn(Callback): - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - return {"returning": "on_save_checkpoint"} - - class TestCallbackSaveHookOverride(Callback): - def on_save_checkpoint(self, trainer, pl_module, checkpoint): - print("overriding without returning") - - model = BoringModel() - trainer = Trainer( - callbacks=[TestCallbackSaveHookReturn()], - max_epochs=1, - fast_dev_run=True, - enable_progress_bar=False, - logger=False, - default_root_dir=tmpdir, - ) - trainer.fit(model) - with pytest.deprecated_call( - match="Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` is deprecated in v1.6" - " and will be removed in v1.8. Please override `Callback.state_dict`" - " to return state to be saved." - ): - trainer.save_checkpoint(tmpdir + "/path.ckpt") - - trainer.callbacks = [TestCallbackSaveHookOverride()] - trainer.save_checkpoint(tmpdir + "/pathok.ckpt") diff --git a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py index 548c7feec4..22188974c8 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py @@ -17,7 +17,7 @@ from unittest import mock import pytest import pytorch_lightning -from pytorch_lightning import Trainer +from pytorch_lightning import Callback, Trainer from pytorch_lightning.demos.boring_classes import BoringModel from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback from tests_pytorch.helpers.runif import RunIf @@ -46,7 +46,7 @@ def test_v2_0_0_deprecated_ipus(_, monkeypatch): _ = Trainer(ipus=4) -def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir): +def test_v2_0_0_resume_from_checkpoint_trainer_constructor(tmpdir): # test resume_from_checkpoint still works until v2.0 deprecation model = BoringModel() callback = OldStatefulCallback(state=111) @@ -84,3 +84,55 @@ def test_v2_0_resume_from_checkpoint_trainer_constructor(tmpdir): trainer = Trainer(resume_from_checkpoint="trainer_arg_path") with pytest.raises(FileNotFoundError, match="Checkpoint at fit_arg_ckpt_path not found. Aborting training."): trainer.fit(model, ckpt_path="fit_arg_ckpt_path") + + +def test_v2_0_0_callback_on_load_checkpoint_hook(tmpdir): + class TestCallbackLoadHook(Callback): + def on_load_checkpoint(self, trainer, pl_module, callback_state): + print("overriding on_load_checkpoint") + + model = BoringModel() + trainer = Trainer( + callbacks=[TestCallbackLoadHook()], + max_epochs=1, + fast_dev_run=True, + enable_progress_bar=False, + logger=False, + default_root_dir=tmpdir, + ) + with pytest.raises( + RuntimeError, match="`TestCallbackLoadHook.on_load_checkpoint` has changed its signature and behavior in v1.8." + ): + trainer.fit(model) + + +def test_v2_0_0_callback_on_save_checkpoint_hook(tmpdir): + class TestCallbackSaveHookReturn(Callback): + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + return {"returning": "on_save_checkpoint"} + + class TestCallbackSaveHookOverride(Callback): + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + print("overriding without returning") + + model = BoringModel() + trainer = Trainer( + callbacks=[TestCallbackSaveHookReturn()], + max_epochs=1, + fast_dev_run=True, + enable_progress_bar=False, + logger=False, + default_root_dir=tmpdir, + ) + trainer.fit(model) + with pytest.raises( + ValueError, + match=( + "Returning a value from `TestCallbackSaveHookReturn.on_save_checkpoint` was deprecated in v1.6 and is" + " no longer supported as of v1.8" + ), + ): + trainer.save_checkpoint(tmpdir + "/path.ckpt") + + trainer.callbacks = [TestCallbackSaveHookOverride()] + trainer.save_checkpoint(tmpdir + "/pathok.ckpt") diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index e6ef33883b..a33d2a2a71 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -637,7 +637,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir): dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")), dict(name="setup", kwargs=dict(stage="fit")), dict(name="on_load_checkpoint", args=(loaded_ckpt,)), - dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})), + dict(name="Callback.on_load_checkpoint", args=(trainer, model, loaded_ckpt)), dict(name="Callback.load_state_dict", args=({"foo": True},)), dict(name="configure_sharded_model"), dict(name="Callback.on_configure_sharded_model", args=(trainer, model)), @@ -726,7 +726,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir): dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")), dict(name="setup", kwargs=dict(stage="fit")), dict(name="on_load_checkpoint", args=(loaded_ckpt,)), - dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})), + dict(name="Callback.on_load_checkpoint", args=(trainer, model, loaded_ckpt)), dict(name="Callback.load_state_dict", args=({"foo": True},)), dict(name="configure_sharded_model"), dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),