Deprecate `on_epoch_start/on_epoch_end` hook (#11578)

Rohit Gupta 2022-02-07 19:45:27 +05:30 committed by GitHub
parent bbf27ed09a
commit 581bf7f2f2
18 changed files with 119 additions and 61 deletions

View File

@@ -323,6 +323,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
- Deprecated `Callback.on_epoch_start` hook in favour of `Callback.on_{train/val/test}_epoch_start` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
- Deprecated `Callback.on_epoch_end` hook in favour of `Callback.on_{train/val/test}_epoch_end` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
- Deprecated `LightningModule.on_epoch_start` hook in favour of `LightningModule.on_{train/val/test}_epoch_start` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
- Deprecated `LightningModule.on_epoch_end` hook in favour of `LightningModule.on_{train/val/test}_epoch_end` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
- Deprecated `on_before_accelerator_backend_setup` callback hook in favour of `setup` ([#11568](https://github.com/PyTorchLightning/pytorch-lightning/pull/11568))
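
In user code the migration is a rename to the matching stage-specific hook. A minimal sketch of the callback-side change (the callback name is hypothetical; `on_validation_epoch_end` and `on_test_epoch_end` cover the other stages):

```python
import pytorch_lightning as pl


class EpochLogger(pl.Callback):
    # Deprecated in v1.6, slated for removal in v1.8:
    # def on_epoch_end(self, trainer, pl_module):
    #     ...

    # Stage-specific replacement:
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        print(f"finished train epoch {trainer.current_epoch}")
```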

View File

@@ -16,8 +16,8 @@ Built-in Actions
PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:
- on_epoch_start
- on_epoch_end
- on_train_epoch_start
- on_train_epoch_end
- on_train_batch_start
- model_forward
- model_backward
@@ -71,12 +71,10 @@ The profiler's results will be printed at the completion of a training ``trainer.fit()``
| on_train_batch_start | 0.00014637 | 0.0010246 |
| [LightningModule]BoringModel.teardown | 2.15e-06 | 2.15e-06 |
| [LightningModule]BoringModel.prepare_data | 1.955e-06 | 1.955e-06 |
| [LightningModule]BoringModel.on_epoch_end | 1.8373e-06 | 5.512e-06 |
| [LightningModule]BoringModel.on_train_start | 1.644e-06 | 1.644e-06 |
| [LightningModule]BoringModel.on_train_end | 1.516e-06 | 1.516e-06 |
| [LightningModule]BoringModel.on_fit_end | 1.426e-06 | 1.426e-06 |
| [LightningModule]BoringModel.setup | 1.403e-06 | 1.403e-06 |
| [LightningModule]BoringModel.on_epoch_start | 1.2883e-06 | 3.865e-06 |
| [LightningModule]BoringModel.on_fit_start | 1.226e-06 | 1.226e-06 |
-----------------------------------------------------------------------------------------------
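
For reference, a minimal sketch of how such a report is produced (assuming PL >= 1.6): passing `profiler="simple"` selects the built-in SimpleProfiler, whose summary is printed when `trainer.fit()` completes.

```python
from pytorch_lightning import Trainer

# "simple" wires up the SimpleProfiler; with the deprecated hooks gone from
# user code, its report lists only the stage-specific epoch hooks.
trainer = Trainer(profiler="simple", max_epochs=1)
```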

View File

@@ -1223,7 +1223,6 @@ for more information.
def fit_loop():
on_epoch_start()
on_train_epoch_start()
for batch in train_dataloader():
@@ -1254,7 +1253,6 @@ for more information.
training_epoch_end()
on_train_epoch_end()
on_epoch_end()
def val_loop():
@@ -1262,7 +1260,6 @@ for more information.
torch.set_grad_enabled(False)
on_validation_start()
on_epoch_start()
on_validation_epoch_start()
val_outs = []
@@ -1281,7 +1278,6 @@ for more information.
validation_epoch_end(val_outs)
on_validation_epoch_end()
on_epoch_end()
on_validation_end()
# set up for train
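
A runnable sketch (model and data are hypothetical) that prints the stage-specific epoch hooks so the ordering in the pseudocode above can be observed directly, with no reliance on the deprecated `on_epoch_start`/`on_epoch_end`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class HookOrderModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    # Stage-specific hooks fire in the order shown in the pseudocode above.
    def on_train_epoch_start(self):
        print("on_train_epoch_start")

    def on_train_epoch_end(self):
        print("on_train_epoch_end")

    def on_validation_epoch_start(self):
        print("on_validation_epoch_start")

    def on_validation_epoch_end(self):
        print("on_validation_epoch_end")


data = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)), batch_size=4)
# num_sanity_val_steps=0 keeps the sanity-check val pass from printing first.
trainer = pl.Trainer(max_epochs=1, num_sanity_val_steps=0, logger=False, enable_checkpointing=False)
trainer.fit(HookOrderModel(), train_dataloaders=data, val_dataloaders=data)
```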
@@ -1474,18 +1470,6 @@ on_train_batch_end
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_batch_end
:noindex:
on_epoch_start
~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_start
:noindex:
on_epoch_end
~~~~~~~~~~~~
.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_end
:noindex:
on_train_epoch_start
~~~~~~~~~~~~~~~~~~~~

View File

@@ -324,15 +324,6 @@ on_predict_epoch_end
.. automethod:: pytorch_lightning.callbacks.Callback.on_predict_epoch_end
:noindex:
on_epoch_start
~~~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_start
:noindex:
on_epoch_end
~~~~~~~~~~~~
.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end
:noindex:

View File

@@ -85,7 +85,7 @@ class ImageSampler(pl.callbacks.Callback):
)
@rank_zero_only
def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if not _TORCHVISION_AVAILABLE:
return

View File

@@ -200,7 +200,7 @@ class GAN(LightningModule):
opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
return [opt_g, opt_d], []
def on_epoch_end(self):
def on_train_epoch_end(self):
z = self.validation_z.type_as(self.generator.model[0].weight)
# log sampled images

View File

@@ -158,10 +158,22 @@ class Callback:
"""Called when the predict epoch ends."""
def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when either of train/val/test epoch begins."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_start`` instead.
Called when either of train/val/test epoch begins.
"""
def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called when either of train/val/test epoch ends."""
r"""
.. deprecated:: v1.6
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_end`` instead.
Called when either of train/val/test epoch ends.
"""
def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
r"""

View File

@@ -182,10 +182,20 @@ class ModelHooks:
self.trainer.model.eval()
def on_epoch_start(self) -> None:
"""Called when either of train/val/test epoch begins."""
r"""
.. deprecated:: v1.6
This hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_start`` instead.
Called when either of train/val/test epoch begins.
"""
def on_epoch_end(self) -> None:
"""Called when either of train/val/test epoch ends."""
r"""
.. deprecated:: v1.6
This hook was deprecated in v1.6 and will be removed in v1.8. Use
``on_<train/validation/test>_epoch_end`` instead.
Called when either of train/val/test epoch ends.
"""
def on_train_epoch_start(self) -> None:
"""Called in the training loop at the very beginning of the epoch."""

View File

@@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
_check_on_init_start_end(trainer)
# TODO: Delete _check_on_hpc_hooks in v1.8
_check_on_hpc_hooks(model)
# TODO: Delete on_epoch_start/on_epoch_end hooks in v1.8
_check_on_epoch_start_end(trainer, model)
# TODO: Delete on_batch_start/on_batch_end hooks in v1.8
_check_on_batch_start_end(trainer, model)
# TODO: Remove this in v1.8
@@ -330,6 +332,29 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
)
# TODO: Remove on_epoch_start/on_epoch_end hooks in v1.8
def _check_on_epoch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
hooks = (
["on_epoch_start", "on_<train/validation/test>_epoch_start"],
["on_epoch_end", "on_<train/validation/test>_epoch_end"],
)
for hook, alternative_hook in hooks:
if is_overridden(hook, model):
rank_zero_deprecation(
f"The `LightningModule.{hook}` hook was deprecated in v1.6 and"
f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead."
)
for hook, alternative_hook in hooks:
for callback in trainer.callbacks:
if is_overridden(method_name=hook, instance=callback):
rank_zero_deprecation(
f"The `Callback.{hook}` hook was deprecated in v1.6 and"
f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
)
# TODO: Remove on_batch_start/on_batch_end hooks in v1.8
def _check_on_batch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
hooks = (["on_batch_start", "on_train_batch_start"], ["on_batch_end", "on_train_batch_end"])
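
A hedged sketch of how the new check surfaces in practice: overriding one of the deprecated hooks makes `trainer.fit()` emit the deprecation message built above (the model here is hypothetical and deliberately tiny):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_epoch_end(self):  # overriding the deprecated hook trips the check
        pass


data = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)))
# Expected on fit: "The `LightningModule.on_epoch_end` hook was deprecated in v1.6
# and will be removed in v1.8. Please use
# `LightningModule.on_<train/validation/test>_epoch_end` instead."
pl.Trainer(fast_dev_run=True).fit(TinyModel(), data)
```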

View File

@@ -63,7 +63,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
def test_free_memory_on_eval_outputs(tmpdir):
class CB(Callback):
def on_epoch_end(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module):
assert not trainer._evaluation_loop._outputs
model = BoringModel()

View File

@@ -21,7 +21,6 @@ import torch
import pytorch_lightning as pl
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping
from tests import _PATH_LEGACY, _PROJECT_ROOT
LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
@@ -55,7 +54,7 @@ class LimitNbEpochs(Callback):
self.limit = nb
self._count = 0
def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._count += 1
if self._count >= self.limit:
trainer.should_stop = True
@@ -73,7 +72,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
dm = ClassifDataModule()
model = ClassificationModel()
es = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
stop = LimitNbEpochs(1)
trainer = Trainer(
@@ -81,7 +79,7 @@
accelerator="auto",
devices=1,
precision=(16 if torch.cuda.is_available() else 32),
callbacks=[es, stop],
callbacks=[stop],
max_epochs=21,
accumulate_grad_batches=2,
)

View File

@@ -432,7 +432,7 @@ def result_collection_reload(accelerator="auto", devices=1, **kwargs):
return super().training_step(batch, batch_idx)
def on_epoch_end(self) -> None:
def on_train_epoch_end(self) -> None:
if self.trainer.fit_loop.restarting:
total = sum(range(5)) * devices
metrics = self.results.metrics(on_step=False)

View File

@@ -415,6 +415,37 @@ def test_v1_8_0_on_configure_sharded_model(tmpdir):
trainer.fit(model)
def test_v1_8_0_remove_on_epoch_start_end_lightning_module(tmpdir):
class CustomModel(BoringModel):
def on_epoch_start(self, *args, **kwargs):
print("on_epoch_start")
model = CustomModel()
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
)
with pytest.deprecated_call(
match="The `LightningModule.on_epoch_start` hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.fit(model)
class CustomModel(BoringModel):
def on_epoch_end(self, *args, **kwargs):
print("on_epoch_end")
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
)
model = CustomModel()
with pytest.deprecated_call(
match="The `LightningModule.on_epoch_end` hook was deprecated in v1.6 and will be removed in v1.8"
):
trainer.fit(model)
def test_v1_8_0_rank_zero_imports():
import warnings

View File

@@ -755,7 +755,6 @@ def test_trainer_model_hook_system_predict(tmpdir):
dict(name="zero_grad"),
dict(name="Callback.on_predict_start", args=(trainer, model)),
dict(name="on_predict_start"),
# TODO: `{,Callback}.on_epoch_{start,end}`
dict(name="Callback.on_predict_epoch_start", args=(trainer, model)),
dict(name="on_predict_epoch_start"),
*model._predict_batch(trainer, model, batches),

View File

@@ -50,9 +50,6 @@ class ModelTrainerPropertyParity(Callback):
def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
self._check_properties(trainer, pl_module)
def on_epoch_end(self, trainer, pl_module):
self._check_properties(trainer, pl_module)
def on_train_end(self, trainer, pl_module):
self._check_properties(trainer, pl_module)

View File

@@ -347,7 +347,9 @@ def test_log_works_in_val_callback(tmpdir):
max_epochs=1,
callbacks=[cb],
)
trainer.fit(model)
# TODO: Update this test in v1.8 (#11578)
with pytest.deprecated_call(match="`Callback.on_epoch_start` hook was deprecated in v1.6"):
trainer.fit(model)
assert cb.call_counter == {
"on_validation_batch_end": 4,
@@ -437,12 +439,6 @@ def test_log_works_in_test_callback(tmpdir):
def on_test_start(self, _, pl_module):
self.make_logging(pl_module, "on_test_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices)
def on_epoch_start(self, trainer, pl_module):
if trainer.testing:
self.make_logging(
pl_module, "on_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices
)
def on_test_epoch_start(self, _, pl_module):
self.make_logging(
pl_module, "on_test_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices

View File

@@ -325,7 +325,10 @@ def test_log_works_in_train_callback(tmpdir):
max_epochs=1,
callbacks=[cb],
)
trainer.fit(model)
# TODO: Update this test in v1.8 (#11578)
with pytest.deprecated_call(match="`Callback.on_epoch_start` hook was deprecated in v1.6"):
trainer.fit(model)
# Make sure the func_name output equals the average from all logged values when on_epoch true
assert trainer.progress_bar_callback.get_metrics(trainer, model)["train_loss"] == model.seen_losses[-1]
@@ -482,11 +485,11 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str):
items.pop("v_num", None)
return items
def on_epoch_end(self, trainer: Trainer, model: LightningModule):
def on_train_end(self, trainer: Trainer, model: LightningModule):
metrics = self.get_metrics(trainer, model)
assert metrics["foo"] == self.trainer.current_epoch
assert metrics["foo_2"] == self.trainer.current_epoch
model.on_epoch_end_called = True
model.callback_on_train_end_called = True
progress_bar = TestProgressBar()
trainer = Trainer(
@@ -502,7 +505,7 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str):
model = TestModel()
trainer.fit(model)
assert model.on_train_epoch_end_called
assert model.on_epoch_end_called
assert model.callback_on_train_end_called
def test_logging_in_callbacks_with_log_function(tmpdir):
@@ -533,7 +536,10 @@ def test_logging_in_callbacks_with_log_function(tmpdir):
enable_model_summary=False,
callbacks=[LoggingCallback()],
)
trainer.fit(model)
# TODO: Update this test in v1.8 (#11578)
with pytest.deprecated_call(match="`Callback.on_epoch_end` hook was deprecated in v1.6"):
trainer.fit(model)
expected = {
"on_train_start": 1,

View File

@@ -412,7 +412,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
num_batches_seen = 0
num_on_load_checkpoint_called = 0
def on_epoch_end(self):
def on_train_epoch_end(self):
self.num_epochs_end_seen += 1
def on_train_batch_start(self, *_):
@@ -435,8 +435,7 @@
)
trainer.fit(model)
# `on_epoch_end` will be called once for val_sanity, twice for train, twice for val
assert model.num_epochs_end_seen == 1 + 2 + 2
assert model.num_epochs_end_seen == 2
assert model.num_batches_seen == trainer.num_training_batches * 2
assert model.num_on_load_checkpoint_called == 0
@@ -1938,7 +1937,7 @@ def test_multiple_trainer_constant_memory_allocated(tmpdir):
return torch.optim.Adam(self.layer.parameters(), lr=0.1)
class Check(Callback):
def on_epoch_start(self, trainer, *_):
def on_train_epoch_start(self, trainer, *_):
assert isinstance(trainer.strategy.model, DistributedDataParallel)
def current_memory():