Remove deprecated precision plugin checkpoint hooks (#14833)

* Remove deprecated precision plugin checkpoint hooks
* chlog
Adrian Wälchli 2022-09-22 13:19:14 +02:00 committed by GitHub
parent 1d3e971903
commit 9505e62350
4 changed files with 4 additions and 57 deletions

src/pytorch_lightning/CHANGELOG.md

@@ -199,9 +199,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed the deprecated way to set the distributed backend via the environment variable `PL_TORCH_DISTRIBUTED_BACKEND`, in favor of setting the `process_group_backend` in the strategy constructor ([#14693](https://github.com/Lightning-AI/lightning/pull/14693))
 - Removed the deprecated device attributes `Trainer.{devices,gpus,num_gpus,ipus,tpu_cores}` in favor of the accelerator-agnostic `Trainer.num_devices` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))
+- Removed the deprecated precision plugin checkpoint hooks `PrecisionPlugin.on_load_checkpoint` and `PrecisionPlugin.on_save_checkpoint` ([#14833](https://github.com/Lightning-AI/lightning/pull/14833))
 - Removed the deprecated `Trainer.root_gpu` attribute in favor of `Trainer.strategy.root_device` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))
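A quick note on the first entry above: a minimal sketch of the replacement for `PL_TORCH_DISTRIBUTED_BACKEND`, assuming the `process_group_backend` constructor argument of `DDPStrategy` referenced in that entry:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Before (removed): the backend was read from the environment, e.g.
#   PL_TORCH_DISTRIBUTED_BACKEND=gloo python train.py
# After: pass it explicitly when constructing the strategy.
trainer = Trainer(
    accelerator="cpu",
    devices=2,
    strategy=DDPStrategy(process_group_backend="gloo"),
)
```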

src/pytorch_lightning/plugins/precision/precision_plugin.py

@@ -269,15 +269,3 @@ class PrecisionPlugin(CheckpointHooks):
             state_dict: the precision plugin state returned by ``state_dict``.
         """
         pass
-
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        """``PrecisionPlugin.on_save_checkpoint`` was deprecated in v1.6 and will be removed in v1.8.
-
-        Use ``state_dict`` instead.
-        """
-
-    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        """``PrecisionPlugin.on_load_checkpoint`` was deprecated in v1.6 and will be removed in v1.8.
-
-        Use ``load_state_dict`` instead.
-        """

src/pytorch_lightning/trainer/configuration_validator.py

@@ -14,7 +14,6 @@
 import pytorch_lightning as pl
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.strategies import DataParallelStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -50,8 +49,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete on_epoch_start/on_epoch_end hooks in v1.8
     _check_on_epoch_start_end(model)
-    # TODO: Delete CheckpointHooks off PrecisionPlugin in v1.8
-    _check_precision_plugin_checkpoint_hooks(trainer)
     # TODO: Delete on_pretrain_routine_start/end hooks in v1.8
     _check_on_pretrain_routine(model)
     # TODO: Delete CheckpointHooks off LightningDataModule in v1.8
@@ -266,19 +263,6 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
         )
 
 
-def _check_precision_plugin_checkpoint_hooks(trainer: "pl.Trainer") -> None:
-    if is_overridden(method_name="on_save_checkpoint", instance=trainer.precision_plugin, parent=PrecisionPlugin):
-        rank_zero_deprecation(
-            "`PrecisionPlugin.on_save_checkpoint` was deprecated in"
-            " v1.6 and will be removed in v1.8. Use `state_dict` instead."
-        )
-    if is_overridden(method_name="on_load_checkpoint", instance=trainer.precision_plugin, parent=PrecisionPlugin):
-        rank_zero_deprecation(
-            "`PrecisionPlugin.on_load_checkpoint` was deprecated in"
-            " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
-        )
-
-
 def _check_datamodule_checkpoint_hooks(trainer: "pl.Trainer") -> None:
     if is_overridden(method_name="on_save_checkpoint", instance=trainer.datamodule):
         rank_zero_deprecation(
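The deleted checker relied on the same `is_overridden` pattern still used by `_check_datamodule_checkpoint_hooks` above. A minimal sketch of that mechanism, assuming the public `pytorch_lightning.utilities.model_helpers.is_overridden` import path; `LegacyPlugin` is hypothetical:

```python
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.model_helpers import is_overridden


class LegacyPlugin(PrecisionPlugin):
    # A user override of the hook the plugin inherits from ``CheckpointHooks``.
    def on_save_checkpoint(self, checkpoint):
        ...


# True: the subclass redefines the method relative to the given parent,
# which is what used to trigger the deprecation warning removed here.
assert is_overridden("on_save_checkpoint", instance=LegacyPlugin(), parent=PrecisionPlugin)
```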

tests/tests_pytorch/deprecated_api/test_remove_1-8.py

@@ -23,7 +23,6 @@ from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
 from pytorch_lightning.loggers import CSVLogger, Logger
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.profilers import AdvancedProfiler, SimpleProfiler
 from pytorch_lightning.strategies.ipu import LightningIPUModule
 from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
@@ -434,34 +433,6 @@ def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list)
     np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2)
 
 
-def test_v1_8_0_precision_plugin_checkpoint_hooks(tmpdir):
-    class PrecisionPluginSaveHook(PrecisionPlugin):
-        def on_save_checkpoint(self, checkpoint):
-            print("override on_save_checkpoint")
-
-    class PrecisionPluginLoadHook(PrecisionPlugin):
-        def on_load_checkpoint(self, checkpoint):
-            print("override on_load_checkpoint")
-
-    model = BoringModel()
-    precplugin_save = PrecisionPluginSaveHook()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, plugins=[precplugin_save])
-    with pytest.deprecated_call(
-        match="`PrecisionPlugin.on_save_checkpoint` was deprecated in"
-        " v1.6 and will be removed in v1.8. Use `state_dict` instead."
-    ):
-        trainer.fit(model)
-
-    precplugin_load = PrecisionPluginLoadHook()
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, plugins=[precplugin_load])
-    with pytest.deprecated_call(
-        match="`PrecisionPlugin.on_load_checkpoint` was deprecated in"
-        " v1.6 and will be removed in v1.8. Use `load_state_dict` instead."
-    ):
-        trainer.fit(model)
-
-
 def test_v1_8_0_datamodule_checkpointhooks():
     class CustomBoringDataModuleSave(BoringDataModule):
         def on_save_checkpoint(self, checkpoint):