Deprecate `LightningDataModule.on_save/load_checkpoint` (#11893)
This commit is contained in:
parent
aea96e45a4
commit
f3253070c4
|
@ -482,6 +482,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Deprecated `PrecisionPlugin.on_{save,load}_checkpoint` in favor of `PrecisionPlugin.{state_dict,load_state_dict}` ([#11978](https://github.com/PyTorchLightning/pytorch-lightning/pull/11978))
|
||||
|
||||
|
||||
- Deprecated `LightningDataModule.on_save/load_checkpoint` in favor of `state_dict/load_state_dict` ([#11893](https://github.com/PyTorchLightning/pytorch-lightning/pull/11893))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
|
||||
|
|
|
@ -323,16 +323,16 @@ on_after_batch_transfer
|
|||
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_after_batch_transfer
|
||||
:noindex:
|
||||
|
||||
on_load_checkpoint
|
||||
load_state_dict
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_load_checkpoint
|
||||
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.load_state_dict
|
||||
:noindex:
|
||||
|
||||
on_save_checkpoint
|
||||
state_dict
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.on_save_checkpoint
|
||||
.. automethod:: pytorch_lightning.core.datamodule.LightningDataModule.state_dict
|
||||
:noindex:
|
||||
|
||||
on_train_dataloader
|
||||
|
|
|
@ -62,6 +62,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
|
|||
_check_precision_plugin_checkpoint_hooks(trainer)
|
||||
# TODO: Delete on_pretrain_routine_start/end hooks in v1.8
|
||||
_check_on_pretrain_routine(model)
|
||||
# TODO: Delete CheckpointHooks off LightningDataModule in v1.8
|
||||
_check_datamodule_checkpoint_hooks(trainer)
|
||||
|
||||
|
||||
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
|
||||
|
@ -395,3 +397,16 @@ def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None:
|
|||
"`PrecisionPlugin.on_load_checkpoint` was deprecated in"
|
||||
" v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
|
||||
)
|
||||
|
||||
|
||||
def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
    """Emit a rank-zero deprecation warning if the attached datamodule still
    overrides the checkpoint hooks scheduled for removal in v1.8.

    Only ``trainer.datamodule`` is inspected; nothing is mutated.
    """
    datamodule = trainer.datamodule
    # (deprecated hook name, exact warning message) pairs — messages are
    # matched verbatim by the deprecation tests, so they must not change.
    deprecations = (
        (
            "on_save_checkpoint",
            "`LightningDataModule.on_save_checkpoint` was deprecated in"
            " v1.6 and will be removed in v1.8. Use `state_dict` instead.",
        ),
        (
            "on_load_checkpoint",
            "`LightningDataModule.on_load_checkpoint` was deprecated in"
            " v1.6 and will be removed in v1.8. Use `load_state_dict` instead.",
        ),
    )
    for hook_name, message in deprecations:
        if is_overridden(method_name=hook_name, instance=datamodule):
            rank_zero_deprecation(message)
|
||||
|
|
|
@ -222,6 +222,10 @@ def test_dm_checkpoint_save_and_load(tmpdir):
|
|||
)
|
||||
|
||||
# fit model
|
||||
with pytest.deprecated_call(
|
||||
match="`LightningDataModule.on_save_checkpoint` was deprecated in"
|
||||
" v1.6 and will be removed in v1.8. Use `state_dict` instead."
|
||||
):
|
||||
trainer.fit(model, datamodule=dm)
|
||||
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||
checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
|
||||
|
|
|
@ -36,6 +36,7 @@ from pytorch_lightning.plugins.training_type.single_device import SingleDevicePl
|
|||
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
|
||||
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
|
||||
from pytorch_lightning.profiler import AbstractProfiler, AdvancedProfiler, SimpleProfiler
|
||||
from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
|
||||
|
@ -740,3 +741,29 @@ def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
|
|||
|
||||
def test_v1_8_0_abstract_profiler():
    """The ``AbstractProfiler`` base class must advertise its v1.6 deprecation
    in its docstring so users see it in help()/rendered API docs."""
    assert "`AbstractProfiler` was deprecated in v1.6" in AbstractProfiler.__doc__
|
||||
|
||||
|
||||
def test_v1_8_0_datamodule_checkpointhooks():
    """Overriding the deprecated ``LightningDataModule.on_save_checkpoint`` /
    ``on_load_checkpoint`` hooks must raise a deprecation warning from
    ``_check_datamodule_checkpoint_hooks`` (hooks slated for removal in v1.8).
    """

    # One subclass per hook so each warning can be asserted in isolation.
    class CustomBoringDataModuleSave(BoringDataModule):
        def on_save_checkpoint(self, checkpoint):
            print("override on_save_checkpoint")

    class CustomBoringDataModuleLoad(BoringDataModule):
        def on_load_checkpoint(self, checkpoint):
            print("override on_load_checkpoint")

    # A Mock trainer suffices: the validator only reads ``trainer.datamodule``.
    trainer = Mock()

    trainer.datamodule = CustomBoringDataModuleSave()
    with pytest.deprecated_call(
        match="`LightningDataModule.on_save_checkpoint` was deprecated in"
        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
    ):
        _check_datamodule_checkpoint_hooks(trainer)

    trainer.datamodule = CustomBoringDataModuleLoad()
    with pytest.deprecated_call(
        match="`LightningDataModule.on_load_checkpoint` was deprecated in"
        " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
    ):
        _check_datamodule_checkpoint_hooks(trainer)
|
||||
|
|
Loading…
Reference in New Issue