From 581bf7f2f20b770004e866b23505eba216780d2f Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 7 Feb 2022 19:45:27 +0530 Subject: [PATCH] Deprecate `on_epoch_start/on_epoch_end` hook (#11578) --- CHANGELOG.md | 12 +++++++ docs/source/advanced/profiler.rst | 6 ++-- docs/source/common/lightning_module.rst | 16 ---------- docs/source/extensions/callbacks.rst | 9 ------ pl_examples/basic_examples/autoencoder.py | 2 +- .../generative_adversarial_net.py | 2 +- pytorch_lightning/callbacks/base.py | 16 ++++++++-- pytorch_lightning/core/hooks.py | 14 +++++++-- .../trainer/configuration_validator.py | 25 +++++++++++++++ tests/callbacks/test_callback_hook_outputs.py | 2 +- .../checkpointing/test_legacy_checkpoints.py | 6 ++-- tests/core/test_metric_result_integration.py | 2 +- tests/deprecated_api/test_remove_1-8.py | 31 +++++++++++++++++++ tests/models/test_hooks.py | 1 - tests/models/test_restore.py | 3 -- .../logging_/test_eval_loop_logging.py | 10 ++---- .../logging_/test_train_loop_logging.py | 16 +++++++--- tests/trainer/test_trainer.py | 7 ++--- 18 files changed, 119 insertions(+), 61 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f18521f881..752e2086e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -323,6 +323,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254)) +- Deprecated `Callback.on_epoch_start` hook in favour of `Callback.on_{train/val/test}_epoch_start` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578)) + + +- Deprecated `Callback.on_epoch_end` hook in favour of `Callback.on_{train/val/test}_epoch_end` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578)) + + +- Deprecated `LightningModule.on_epoch_start` hook in favor of `LightningModule.on_{train/val/test}_epoch_start` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578)) + + +- Deprecated `LightningModule.on_epoch_end` hook in favor of `LightningModule.on_{train/val/test}_epoch_end` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578)) + + - Deprecated `on_before_accelerator_backend_setup` callback hook in favour of `setup` ([#11568](https://github.com/PyTorchLightning/pytorch-lightning/pull/11568)) diff --git a/docs/source/advanced/profiler.rst b/docs/source/advanced/profiler.rst index 592edd7fa2..eb9bad0286 100644 --- a/docs/source/advanced/profiler.rst +++ b/docs/source/advanced/profiler.rst @@ -16,8 +16,8 @@ Built-in Actions PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: -- on_epoch_start -- on_epoch_end +- on_train_epoch_start +- on_train_epoch_end - on_train_batch_start - model_forward - model_backward @@ -71,12 +71,10 @@ The profiler's results will be printed at the completion of a training ``trainer | on_train_batch_start | 0.00014637 | 0.0010246 | | [LightningModule]BoringModel.teardown | 2.15e-06 | 2.15e-06 | | [LightningModule]BoringModel.prepare_data | 1.955e-06 | 1.955e-06 | - | [LightningModule]BoringModel.on_epoch_end | 1.8373e-06 | 5.512e-06 | | [LightningModule]BoringModel.on_train_start | 1.644e-06 | 1.644e-06 | | 
[LightningModule]BoringModel.on_train_end | 1.516e-06 | 1.516e-06 | | [LightningModule]BoringModel.on_fit_end | 1.426e-06 | 1.426e-06 | | [LightningModule]BoringModel.setup | 1.403e-06 | 1.403e-06 | - | [LightningModule]BoringModel.on_epoch_start | 1.2883e-06 | 3.865e-06 | | [LightningModule]BoringModel.on_fit_start | 1.226e-06 | 1.226e-06 | ----------------------------------------------------------------------------------------------- diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst index e416d329ec..f38f6ae5de 100644 --- a/docs/source/common/lightning_module.rst +++ b/docs/source/common/lightning_module.rst @@ -1223,7 +1223,6 @@ for more information. def fit_loop(): - on_epoch_start() on_train_epoch_start() for batch in train_dataloader(): @@ -1254,7 +1253,6 @@ for more information. training_epoch_end() on_train_epoch_end() - on_epoch_end() def val_loop(): @@ -1262,7 +1260,6 @@ for more information. torch.set_grad_enabled(False) on_validation_start() - on_epoch_start() on_validation_epoch_start() val_outs = [] @@ -1281,7 +1278,6 @@ for more information. validation_epoch_end(val_outs) on_validation_epoch_end() - on_epoch_end() on_validation_end() # set up for train @@ -1474,18 +1470,6 @@ on_train_batch_end .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_batch_end :noindex: -on_epoch_start -~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_start - :noindex: - -on_epoch_end -~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_end - :noindex: - on_train_epoch_start ~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index 2cf5889b3b..ab4894de77 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -324,15 +324,6 @@ on_predict_epoch_end .. 
automethod:: pytorch_lightning.callbacks.Callback.on_predict_epoch_end :noindex: -on_epoch_start -~~~~~~~~~~~~~~ - -.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_start - :noindex: - -on_epoch_end -~~~~~~~~~~~~ - .. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end :noindex: diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py index bd34389b21..496d7e6d9b 100644 --- a/pl_examples/basic_examples/autoencoder.py +++ b/pl_examples/basic_examples/autoencoder.py @@ -85,7 +85,7 @@ class ImageSampler(pl.callbacks.Callback): ) @rank_zero_only - def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: if not _TORCHVISION_AVAILABLE: return diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 26a6c8aa89..fd2cf69f14 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -200,7 +200,7 @@ class GAN(LightningModule): opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) return [opt_g, opt_d], [] - def on_epoch_end(self): + def on_train_epoch_end(self): z = self.validation_z.type_as(self.generator.model[0].weight) # log sampled images diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 461b552555..a24fef72e5 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -158,10 +158,22 @@ class Callback: """Called when the predict epoch ends.""" def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when either of train/val/test epoch begins.""" + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_<train/validation/test>_epoch_start`` instead.
+ + Called when either of train/val/test epoch begins. + """ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - """Called when either of train/val/test epoch ends.""" + r""" + .. deprecated:: v1.6 + This callback hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_<train/validation/test>_epoch_end`` instead. + + Called when either of train/val/test epoch ends. + """ def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: r""" diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index c5aafdecd1..b5a638d710 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -182,10 +182,20 @@ class ModelHooks: self.trainer.model.eval() def on_epoch_start(self) -> None: - """Called when either of train/val/test epoch begins.""" + r""" + .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_<train/validation/test>_epoch_start`` instead. + + Called when either of train/val/test epoch begins. + """ def on_epoch_end(self) -> None: - """Called when either of train/val/test epoch ends.""" + r""" + .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use + ``on_<train/validation/test>_epoch_end`` instead. + + Called when either of train/val/test epoch ends.
+ """ def on_train_epoch_start(self) -> None: """Called in the training loop at the very beginning of the epoch.""" diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index bbf036e26b..72c3a90bfb 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: _check_on_init_start_end(trainer) # TODO: Delete _check_on_hpc_hooks in v1.8 _check_on_hpc_hooks(model) + # TODO: Delete on_epoch_start/on_epoch_end hooks in v1.8 + _check_on_epoch_start_end(trainer, model) # TODO: Delete on_batch_start/on_batch_end hooks in v1.8 _check_on_batch_start_end(trainer, model) # TODO: Remove this in v1.8 @@ -330,6 +332,29 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None: ) +# TODO: Remove on_epoch_start/on_epoch_end hooks in v1.8 +def _check_on_epoch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: + hooks = ( + ["on_epoch_start", "on_<train/validation/test>_epoch_start"], + ["on_epoch_end", "on_<train/validation/test>_epoch_end"], + ) + + for hook, alternative_hook in hooks: + if is_overridden(hook, model): + rank_zero_deprecation( + f"The `LightningModule.{hook}` hook was deprecated in v1.6 and" + f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead." + ) + + for hook, alternative_hook in hooks: + for callback in trainer.callbacks: + if is_overridden(method_name=hook, instance=callback): + rank_zero_deprecation( + f"The `Callback.{hook}` hook was deprecated in v1.6 and" + f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
+ ) + + # TODO: Remove on_batch_start/on_batch_end hooks in v1.8 def _check_on_batch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: hooks = (["on_batch_start", "on_train_batch_start"], ["on_batch_end", "on_train_batch_end"]) diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index f7c9321cd0..d4921a3486 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -63,7 +63,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool): def test_free_memory_on_eval_outputs(tmpdir): class CB(Callback): - def on_epoch_end(self, trainer, pl_module): + def on_train_epoch_end(self, trainer, pl_module): assert not trainer._evaluation_loop._outputs model = BoringModel() diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py index e26f02603f..7e753617a6 100644 --- a/tests/checkpointing/test_legacy_checkpoints.py +++ b/tests/checkpointing/test_legacy_checkpoints.py @@ -21,7 +21,6 @@ import torch import pytorch_lightning as pl from pytorch_lightning import Callback, Trainer -from pytorch_lightning.callbacks import EarlyStopping from tests import _PATH_LEGACY, _PROJECT_ROOT LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints") @@ -55,7 +54,7 @@ class LimitNbEpochs(Callback): self.limit = nb self._count = 0 - def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self._count += 1 if self._count >= self.limit: trainer.should_stop = True @@ -73,7 +72,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str): dm = ClassifDataModule() model = ClassificationModel() - es = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005) stop = LimitNbEpochs(1) trainer = Trainer( @@ -81,7 +79,7 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: 
str): accelerator="auto", devices=1, precision=(16 if torch.cuda.is_available() else 32), - callbacks=[es, stop], + callbacks=[stop], max_epochs=21, accumulate_grad_batches=2, ) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 05cdb4fe92..0ae5be4dff 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -432,7 +432,7 @@ def result_collection_reload(accelerator="auto", devices=1, **kwargs): return super().training_step(batch, batch_idx) - def on_epoch_end(self) -> None: + def on_train_epoch_end(self) -> None: if self.trainer.fit_loop.restarting: total = sum(range(5)) * devices metrics = self.results.metrics(on_step=False) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 049c53c68b..1f7a92d074 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -415,6 +415,37 @@ def test_v1_8_0_on_configure_sharded_model(tmpdir): trainer.fit(model) +def test_v1_8_0_remove_on_epoch_start_end_lightning_module(tmpdir): + class CustomModel(BoringModel): + def on_epoch_start(self, *args, **kwargs): + print("on_epoch_start") + + model = CustomModel() + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + ) + with pytest.deprecated_call( + match="The `LightningModule.on_epoch_start` hook was deprecated in v1.6 and will be removed in v1.8" + ): + trainer.fit(model) + + class CustomModel(BoringModel): + def on_epoch_end(self, *args, **kwargs): + print("on_epoch_end") + + trainer = Trainer( + fast_dev_run=True, + default_root_dir=tmpdir, + ) + + model = CustomModel() + with pytest.deprecated_call( + match="The `LightningModule.on_epoch_end` hook was deprecated in v1.6 and will be removed in v1.8" + ): + trainer.fit(model) + + def test_v1_8_0_rank_zero_imports(): import warnings diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 
5f20d7bb41..e9ea468c4a 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -755,7 +755,6 @@ def test_trainer_model_hook_system_predict(tmpdir): dict(name="zero_grad"), dict(name="Callback.on_predict_start", args=(trainer, model)), dict(name="on_predict_start"), - # TODO: `{,Callback}.on_epoch_{start,end}` dict(name="Callback.on_predict_epoch_start", args=(trainer, model)), dict(name="on_predict_epoch_start"), *model._predict_batch(trainer, model, batches), diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index cbc2b35553..8e1006ce73 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -50,9 +50,6 @@ class ModelTrainerPropertyParity(Callback): def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): self._check_properties(trainer, pl_module) - def on_epoch_end(self, trainer, pl_module): - self._check_properties(trainer, pl_module) - def on_train_end(self, trainer, pl_module): self._check_properties(trainer, pl_module) diff --git a/tests/trainer/logging_/test_eval_loop_logging.py b/tests/trainer/logging_/test_eval_loop_logging.py index 125f73e9f7..6388bfe6b2 100644 --- a/tests/trainer/logging_/test_eval_loop_logging.py +++ b/tests/trainer/logging_/test_eval_loop_logging.py @@ -347,7 +347,9 @@ def test_log_works_in_val_callback(tmpdir): max_epochs=1, callbacks=[cb], ) - trainer.fit(model) + # TODO: Update this test in v1.8 (#11578) + with pytest.deprecated_call(match="`Callback.on_epoch_start` hook was deprecated in v1.6"): + trainer.fit(model) assert cb.call_counter == { "on_validation_batch_end": 4, @@ -437,12 +439,6 @@ def test_log_works_in_test_callback(tmpdir): def on_test_start(self, _, pl_module): self.make_logging(pl_module, "on_test_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices) - def on_epoch_start(self, trainer, pl_module): - if trainer.testing: - self.make_logging( - pl_module, "on_epoch_start", on_steps=[False], on_epochs=[True], 
prob_bars=self.choices - ) - def on_test_epoch_start(self, _, pl_module): self.make_logging( pl_module, "on_test_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 870bc59865..0a4247a28e 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -325,7 +325,10 @@ def test_log_works_in_train_callback(tmpdir): max_epochs=1, callbacks=[cb], ) - trainer.fit(model) + + # TODO: Update this test in v1.8 (#11578) + with pytest.deprecated_call(match="`Callback.on_epoch_start` hook was deprecated in v1.6"): + trainer.fit(model) # Make sure the func_name output equals the average from all logged values when on_epoch true assert trainer.progress_bar_callback.get_metrics(trainer, model)["train_loss"] == model.seen_losses[-1] @@ -482,11 +485,11 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str): items.pop("v_num", None) return items - def on_epoch_end(self, trainer: Trainer, model: LightningModule): + def on_train_end(self, trainer: Trainer, model: LightningModule): metrics = self.get_metrics(trainer, model) assert metrics["foo"] == self.trainer.current_epoch assert metrics["foo_2"] == self.trainer.current_epoch - model.on_epoch_end_called = True + model.callback_on_train_end_called = True progress_bar = TestProgressBar() trainer = Trainer( @@ -502,7 +505,7 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str): model = TestModel() trainer.fit(model) assert model.on_train_epoch_end_called - assert model.on_epoch_end_called + assert model.callback_on_train_end_called def test_logging_in_callbacks_with_log_function(tmpdir): @@ -533,7 +536,10 @@ def test_logging_in_callbacks_with_log_function(tmpdir): enable_model_summary=False, callbacks=[LoggingCallback()], ) - trainer.fit(model) + + # TODO: Update this test in v1.8 
(#11578) + with pytest.deprecated_call(match="`Callback.on_epoch_end` hook was deprecated in v1.6"): + trainer.fit(model) expected = { "on_train_start": 1, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 64f53fe308..8fc1ad488d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -412,7 +412,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck num_batches_seen = 0 num_on_load_checkpoint_called = 0 - def on_epoch_end(self): + def on_train_epoch_end(self): self.num_epochs_end_seen += 1 def on_train_batch_start(self, *_): @@ -435,8 +435,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ck ) trainer.fit(model) - # `on_epoch_end` will be called once for val_sanity, twice for train, twice for val - assert model.num_epochs_end_seen == 1 + 2 + 2 + assert model.num_epochs_end_seen == 2 assert model.num_batches_seen == trainer.num_training_batches * 2 assert model.num_on_load_checkpoint_called == 0 @@ -1938,7 +1937,7 @@ def test_multiple_trainer_constant_memory_allocated(tmpdir): return torch.optim.Adam(self.layer.parameters(), lr=0.1) class Check(Callback): - def on_epoch_start(self, trainer, *_): + def on_train_epoch_start(self, trainer, *_): assert isinstance(trainer.strategy.model, DistributedDataParallel) def current_memory():