From 9505e62350b6006cdfa4810ca7aeddcdf2a5b69e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 22 Sep 2022 13:19:14 +0200 Subject: [PATCH] Remove deprecated precision plugin checkpoint hooks (#14833) * Remove deprecated precision plugin checkpoint hooks * chlog --- src/pytorch_lightning/CHANGELOG.md | 4 +++ .../plugins/precision/precision_plugin.py | 12 -------- .../trainer/configuration_validator.py | 16 ---------- .../deprecated_api/test_remove_1-8.py | 29 ------------------- 4 files changed, 4 insertions(+), 57 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index d587d93a98..e2cbeefca7 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -199,9 +199,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the deprecated way to set the distributed backend via the environment variable `PL_TORCH_DISTRIBUTED_BACKEND`, in favor of setting the `process_group_backend` in the strategy constructor ([#14693](https://github.com/Lightning-AI/lightning/pull/14693)) + - Removed the deprecated device attributes `Trainer.{devices,gpus,num_gpus,ipus,tpu_cores}` in favor of the accelerator-agnostic `Trainer.num_devices` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829)) +- Removed the deprecated precision plugin checkpoint hooks `PrecisionPlugin.on_load_checkpoint` and `PrecisionPlugin.on_save_checkpoint` ([#14833](https://github.com/Lightning-AI/lightning/pull/14833)) + + - Removed the deprecated `Trainer.root_gpu` attribute in favor of `Trainer.strategy.root_device` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829)) diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py index 063c8cabb7..790ab99707 100644 --- a/src/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py @@ -269,15 +269,3 @@ class PrecisionPlugin(CheckpointHooks): state_dict: the precision plugin state returned by ``state_dict``. """ pass - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """``PrecisionPlugin.on_save_checkpoint`` was deprecated in v1.6 and will be removed in v1.8. - - Use ``state_dict`` instead. - """ - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """``PrecisionPlugin.on_load_checkpoint`` was deprecated in v1.6 and will be removed in v1.8. - - Use ``load_state_dict`` instead. - """ diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index 023ccb09bd..796445d7da 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -14,7 +14,6 @@ import pytorch_lightning as pl from lightning_lite.utilities.warnings import PossibleUserWarning from pytorch_lightning.accelerators.ipu import IPUAccelerator -from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.strategies import DataParallelStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -50,8 +49,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: _check_deprecated_callback_hooks(trainer) # TODO: Delete on_epoch_start/on_epoch_end hooks in v1.8 _check_on_epoch_start_end(model) - # TODO: Delete CheckpointHooks off PrecisionPlugin in v1.8 - _check_precision_plugin_checkpoint_hooks(trainer) # TODO: Delete on_pretrain_routine_start/end hooks in v1.8 _check_on_pretrain_routine(model) # TODO: Delete CheckpointHooks off LightningDataModule in v1.8 @@ -266,19 +263,6 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: ) -def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None: - if is_overridden(method_name="on_save_checkpoint", instance=trainer.precision_plugin, parent=PrecisionPlugin): - rank_zero_deprecation( - "`PrecisionPlugin.on_save_checkpoint` was deprecated in" - " v1.6 and will be removed in v1.8. Use `state_dict` instead." - ) - if is_overridden(method_name="on_load_checkpoint", instance=trainer.precision_plugin, parent=PrecisionPlugin): - rank_zero_deprecation( - "`PrecisionPlugin.on_load_checkpoint` was deprecated in" - " v1.6 and will be removed in v1.8. Use `load_state_dict` instead." - ) - - def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None: if is_overridden(method_name="on_save_checkpoint", instance=trainer.datamodule): rank_zero_deprecation( diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py index ce391c869b..3064f2f260 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-8.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-8.py @@ -23,7 +23,6 @@ from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel from pytorch_lightning.loggers import CSVLogger, Logger -from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.profilers import AdvancedProfiler, SimpleProfiler from pytorch_lightning.strategies.ipu import LightningIPUModule from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks @@ -434,34 +433,6 @@ def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list) np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2) -def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir): - class PrecisionPluginSaveHook(PrecisionPlugin): - def on_save_checkpoint(self, checkpoint): - print("override on_save_checkpoint") - - class PrecisionPluginLoadHook(PrecisionPlugin): - def on_load_checkpoint(self, checkpoint): - print("override on_load_checkpoint") - - model = BoringModel() - - precplugin_save = PrecisionPluginSaveHook() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, plugins=[precplugin_save]) - with pytest.deprecated_call( - match="`PrecisionPlugin.on_save_checkpoint` was deprecated in" - " v1.6 and will be removed in v1.8. Use `state_dict` instead." - ): - trainer.fit(model) - - precplugin_load = PrecisionPluginLoadHook() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, plugins=[precplugin_load]) - with pytest.deprecated_call( - match="`PrecisionPlugin.on_load_checkpoint` was deprecated in" - " v1.6 and will be removed in v1.8. Use `load_state_dict` instead." - ): - trainer.fit(model) - - def test_v1_8_0_datamodule_checkpointhooks(): class CustomBoringDataModuleSave(BoringDataModule): def on_save_checkpoint(self, checkpoint):