Deprecate `LightningDataModule.on_save/load_checkpoint` (#11893)

This commit is contained in:
jjenniferdai 2022-03-07 18:21:46 -08:00 committed by GitHub
parent aea96e45a4
commit f3253070c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 54 additions and 5 deletions

View File

@ -482,6 +482,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978))
- Deprecated `LightningDataModule.on_save/load_checkpoint` in favor of `state_dict/load_state_dict` ([#11893](https://github.com/PyTorchLightning/pytorch-lightning/pull/11893))
### Removed
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

View File

@ -323,16 +323,16 @@ on_after_batch_transfer
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_after_batch_transfer
:noindex:
on_load_checkpoint
load_state_dict
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_load_checkpoint
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.load_state_dict
:noindex:
on_save_checkpoint
state_dict
~~~~~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_save_checkpoint
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.state_dict
:noindex:
on_train_dataloader

View File

@ -62,6 +62,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
_check_precision_plugin_checkpoint_hooks(trainer)
# TODO: Delete on_pretrain_routine_start/end hooks in v1.8
_check_on_pretrain_routine(model)
# TODO: Delete CheckpointHooks off LightningDataModule in v1.8
_check_datamodule_checkpoint_hooks(trainer)
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@ -395,3 +397,16 @@ def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None:
"`PrecisionPlugin.on_load_checkpoint` was deprecated in"
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
)
def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
if is_overridden(method_name="on_save_checkpoint", instance=trainer.datamodule):
rank_zero_deprecation(
"`LightningDataModule.on_save_checkpoint` was deprecated in"
" v1.6 and will be removed in v1.8. Use `state_dict` instead."
)
if is_overridden(method_name="on_load_checkpoint", instance=trainer.datamodule):
rank_zero_deprecation(
"`LightningDataModule.on_load_checkpoint` was deprecated in"
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
)

View File

@ -222,7 +222,11 @@ def test_dm_checkpoint_save_and_load(tmpdir):
)
# fit model
trainer.fit(model, datamodule=dm)
with pytest.deprecated_call(
match="`LightningDataModule.on_save_checkpoint` was deprecated in"
" v1.6 and will be removed in v1.8. Use `state_dict` instead."
):
trainer.fit(model, datamodule=dm)
assert trainer.state.finished, f"Training failed with {trainer.state}"
checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
checkpoint = torch.load(checkpoint_path)

View File

@ -36,6 +36,7 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePl
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.profiler import AbstractProfiler, AdvancedProfiler, SimpleProfiler
from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
@ -740,3 +741,29 @@ def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
def test_v1_8_0_abstract_profiler():
assert "`AbstractProfiler` was deprecated in v1.6" in AbstractProfiler.__doc__
def test_v1_8_0_datamodule_checkpointhooks():
class CustomBoringDataModuleSave(BoringDataModule):
def on_save_checkpoint(self, checkpoint):
print("override on_save_checkpoint")
class CustomBoringDataModuleLoad(BoringDataModule):
def on_load_checkpoint(self, checkpoint):
print("override on_load_checkpoint")
trainer = Mock()
trainer.datamodule = CustomBoringDataModuleSave()
with pytest.deprecated_call(
match="`LightningDataModule.on_save_checkpoint` was deprecated in"
" v1.6 and will be removed in v1.8. Use `state_dict` instead."
):
_check_datamodule_checkpoint_hooks(trainer)
trainer.datamodule = CustomBoringDataModuleLoad()
with pytest.deprecated_call(
match="`LightningDataModule.on_load_checkpoint` was deprecated in"
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
):
_check_datamodule_checkpoint_hooks(trainer)