From fe32b39dbc468cf21453a982869ab57ccc3b0d47 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 11 Oct 2022 15:18:47 -0400 Subject: [PATCH] Error messages for the remaining callback hooks (#15064) --- src/pytorch_lightning/_graveyard/__init__.py | 1 + src/pytorch_lightning/_graveyard/callbacks.py | 30 ++++++++++++++++ .../callbacks/model_checkpoint.py | 7 ---- .../trainer/configuration_validator.py | 8 +++++ .../deprecated_api/test_remove_2-0.py | 36 ++++++++++++++----- .../tests_pytorch/graveyard/test_callbacks.py | 26 ++++++++++++++ 6 files changed, 93 insertions(+), 15 deletions(-) create mode 100644 src/pytorch_lightning/_graveyard/callbacks.py create mode 100644 tests/tests_pytorch/graveyard/test_callbacks.py diff --git a/src/pytorch_lightning/_graveyard/__init__.py b/src/pytorch_lightning/_graveyard/__init__.py index 22f1df3e63..cc27bda688 100644 --- a/src/pytorch_lightning/_graveyard/__init__.py +++ b/src/pytorch_lightning/_graveyard/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytorch_lightning._graveyard.callbacks import pytorch_lightning._graveyard.trainer import pytorch_lightning._graveyard.training_type # noqa: F401 diff --git a/src/pytorch_lightning/_graveyard/callbacks.py b/src/pytorch_lightning/_graveyard/callbacks.py new file mode 100644 index 0000000000..ac972cb7ed --- /dev/null +++ b/src/pytorch_lightning/_graveyard/callbacks.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +from pytorch_lightning.callbacks import ModelCheckpoint + + +def _save_checkpoint(_: ModelCheckpoint, __: Any) -> None: + # Remove in v2.0.0 + raise NotImplementedError( + f"`{ModelCheckpoint.__name__}.save_checkpoint()` was deprecated in v1.6 and is no longer supported" + f" as of 1.8. Please use `trainer.save_checkpoint()` to manually save a checkpoint. This method will be" + f" removed completely in v2.0." + ) + + +# Methods +ModelCheckpoint.save_checkpoint = _save_checkpoint diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py index 256f913659..f95acbab61 100644 --- a/src/pytorch_lightning/callbacks/model_checkpoint.py +++ b/src/pytorch_lightning/callbacks/model_checkpoint.py @@ -351,13 +351,6 @@ class ModelCheckpoint(Checkpoint): self.best_model_path = state_dict["best_model_path"] - def save_checkpoint(self, trainer: "pl.Trainer") -> None: - raise NotImplementedError( - f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and is no longer supported" - f" as of 1.8. Please use `trainer.save_checkpoint()` to manually save a checkpoint. This method will be" - f" removed completely in v2.0." - ) - def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: if self.save_top_k == 0: return diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index b14f955312..513382ca29 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -212,6 +212,14 @@ def _check_on_pretrain_routine(model: "pl.LightningModule") -> None: def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None: for callback in trainer.callbacks: + if callable(getattr(callback, "on_init_start", None)): + raise RuntimeError( + "The `on_init_start` callback hook was deprecated in v1.6 and is no longer supported as of v1.8." + ) + if callable(getattr(callback, "on_init_end", None)): + raise RuntimeError( + "The `on_init_end` callback hook was deprecated in v1.6 and is no longer supported as of v1.8." + ) if callable(getattr(callback, "on_configure_sharded_model", None)): raise RuntimeError( "The `on_configure_sharded_model` callback hook was removed in v1.8. Use `setup()` instead." diff --git a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py index 9457f264fd..b082ed5394 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py @@ -18,7 +18,6 @@ import pytest import pytorch_lightning from pytorch_lightning import Callback, Trainer -from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel from tests_pytorch.callbacks.test_callbacks import OldStatefulCallback from tests_pytorch.helpers.runif import RunIf @@ -292,11 +291,32 @@ def test_v2_0_0_callback_on_pretrain_routine_start_end(tmpdir): trainer.fit(model) -def test_v2_0_0_deprecated_mc_save_checkpoint(): - mc = ModelCheckpoint() - trainer = Trainer() - with mock.patch.object(trainer, "save_checkpoint"), pytest.raises( - NotImplementedError, - match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6 and is no longer supported as of 1.8.", +class OnInitStartCallback(Callback): + def on_init_start(self, trainer): + print("Starting to init trainer!") + + +class OnInitEndCallback(Callback): + def on_init_end(self, trainer): + print("Trainer is init now") + + +@pytest.mark.parametrize("callback_class", [OnInitStartCallback, OnInitEndCallback]) +def test_v2_0_0_unsupported_on_init_start_end(callback_class, tmpdir): + model = BoringModel() + trainer = Trainer( + callbacks=[callback_class()], + max_epochs=1, + fast_dev_run=True, + enable_progress_bar=False, + logger=False, + default_root_dir=tmpdir, + ) + with pytest.raises( + RuntimeError, match="callback hook was deprecated in v1.6 and is no longer supported as of v1.8" ): - mc.save_checkpoint(trainer) + trainer.fit(model) + with pytest.raises( + RuntimeError, match="callback hook was deprecated in v1.6 and is no longer supported as of v1.8" + ): + trainer.validate(model) diff --git a/tests/tests_pytorch/graveyard/test_callbacks.py b/tests/tests_pytorch/graveyard/test_callbacks.py new file mode 100644 index 0000000000..17f8db135e --- /dev/null +++ b/tests/tests_pytorch/graveyard/test_callbacks.py @@ -0,0 +1,26 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from pytorch_lightning.callbacks import ModelCheckpoint + + +def test_v2_0_0_deprecated_mc_save_checkpoint(): + mc = ModelCheckpoint() + with pytest.raises( + NotImplementedError, + match=r"ModelCheckpoint.save_checkpoint\(\)` was deprecated in v1.6 and is no longer supported as of 1.8.", + ): + mc.save_checkpoint(None)