From f3253070c4e4f02b41f80a97d2e398964de9223e Mon Sep 17 00:00:00 2001
From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com>
Date: Mon, 7 Mar 2022 18:21:46 -0800
Subject: [PATCH] Deprecate `LightningDataModule.on_save/load_checkpoint`
 (#11893)

---
 CHANGELOG.md                            |  3 +++
 docs/source/extensions/datamodules.rst  |  8 +++---
 .../trainer/configuration_validator.py  | 15 +++++++++++
 tests/core/test_datamodules.py          |  6 ++++-
 tests/deprecated_api/test_remove_1-8.py | 27 +++++++++++++++++++
 5 files changed, 54 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fd8d25ccaf..bbc65633bd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -482,6 +482,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978))
 
 
+- Deprecated `LightningDataModule.on_save/load_checkpoint` in favor of `state_dict/load_state_dict` ([#11893](https://github.com/PyTorchLightning/pytorch-lightning/pull/11893))
+
+
 ### Removed
 
 - Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst
index 93eb9ce319..1bbcbcb83a 100644
--- a/docs/source/extensions/datamodules.rst
+++ b/docs/source/extensions/datamodules.rst
@@ -323,16 +323,16 @@ on_after_batch_transfer
 .. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_after_batch_transfer
     :noindex:
 
-on_load_checkpoint
+load_state_dict
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_load_checkpoint
+.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.load_state_dict
     :noindex:
 
-on_save_checkpoint
+state_dict
 ~~~~~~~~~~~~~~~~~~
 
-.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_save_checkpoint
+.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.state_dict
     :noindex:
 
 on_train_dataloader
diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py
index 739d49c6c2..07ffba0860 100644
--- a/pytorch_lightning/trainer/configuration_validator.py
+++ b/pytorch_lightning/trainer/configuration_validator.py
@@ -62,6 +62,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_precision_plugin_checkpoint_hooks(trainer)
     # TODO: Delete on_pretrain_routine_start/end hooks in v1.8
     _check_on_pretrain_routine(model)
+    # TODO: Delete CheckpointHooks off LightningDataModule in v1.8
+    _check_datamodule_checkpoint_hooks(trainer)
 
 
 def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -395,3 +397,16 @@ def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None:
             "`PrecisionPlugin.on_load_checkpoint` was deprecated in"
             " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
         )
+
+
+def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
+    if is_overridden(method_name="on_save_checkpoint", instance=trainer.datamodule):
+        rank_zero_deprecation(
+            "`LightningDataModule.on_save_checkpoint` was deprecated in"
+            " v1.6 and will be removed in v1.8. Use `state_dict` instead."
+        )
+    if is_overridden(method_name="on_load_checkpoint", instance=trainer.datamodule):
+        rank_zero_deprecation(
+            "`LightningDataModule.on_load_checkpoint` was deprecated in"
+            " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
+        )
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index b2a4f58762..f015d96fef 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -222,7 +222,11 @@ def test_dm_checkpoint_save_and_load(tmpdir):
     )
 
     # fit model
-    trainer.fit(model, datamodule=dm)
+    with pytest.deprecated_call(
+        match="`LightningDataModule.on_save_checkpoint` was deprecated in"
+        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
+    ):
+        trainer.fit(model, datamodule=dm)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
     checkpoint = torch.load(checkpoint_path)
diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py
index bf24f81b8e..0fa174e172 100644
--- a/tests/deprecated_api/test_remove_1-8.py
+++ b/tests/deprecated_api/test_remove_1-8.py
@@ -36,6 +36,7 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePl
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
 from pytorch_lightning.profiler import AbstractProfiler, AdvancedProfiler, SimpleProfiler
+from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
 from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
@@ -740,3 +741,29 @@ def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
 
 def test_v1_8_0_abstract_profiler():
     assert "`AbstractProfiler` was deprecated in v1.6" in AbstractProfiler.__doc__
+
+
+def test_v1_8_0_datamodule_checkpointhooks():
+    class CustomBoringDataModuleSave(BoringDataModule):
+        def on_save_checkpoint(self, checkpoint):
+            print("override on_save_checkpoint")
+
+    class CustomBoringDataModuleLoad(BoringDataModule):
+        def on_load_checkpoint(self, checkpoint):
+            print("override on_load_checkpoint")
+
+    trainer = Mock()
+
+    trainer.datamodule = CustomBoringDataModuleSave()
+    with pytest.deprecated_call(
+        match="`LightningDataModule.on_save_checkpoint` was deprecated in"
+        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
+    ):
+        _check_datamodule_checkpoint_hooks(trainer)
+
+    trainer.datamodule = CustomBoringDataModuleLoad()
+    with pytest.deprecated_call(
+        match="`LightningDataModule.on_load_checkpoint` was deprecated in"
+        " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
+    ):
+        _check_datamodule_checkpoint_hooks(trainer)
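As a usage note, the sketch below shows how a user-defined DataModule might migrate off the hooks this patch deprecates: state that used to be written into the checkpoint dict in `on_save_checkpoint` is instead returned from `state_dict`, and restored in `load_state_dict` rather than `on_load_checkpoint`. The `MyDataModule` class and its `current_fold` attribute are illustrative placeholders, not part of this patch.

from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def __init__(self) -> None:
        super().__init__()
        # Hypothetical piece of resumable state, used only for illustration.
        self.current_fold = 0

    # Before (deprecated in v1.6, removed in v1.8):
    #     def on_save_checkpoint(self, checkpoint):
    #         checkpoint["current_fold"] = self.current_fold
    #     def on_load_checkpoint(self, checkpoint):
    #         self.current_fold = checkpoint["current_fold"]

    # After: return the state to be included in the checkpoint ...
    def state_dict(self) -> Dict[str, Any]:
        return {"current_fold": self.current_fold}

    # ... and restore it when the checkpoint is loaded.
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.current_fold = state_dict["current_fold"]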