Deprecate `on_epoch_start/on_epoch_end` hook (#11578)
parent bbf27ed09a
commit 581bf7f2f2
CHANGELOG.md | 12
@@ -323,6 +323,18 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))

+- Deprecated `Callback.on_epoch_start` hook in favour of `Callback.on_{train/val/test}_epoch_start` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
+
+- Deprecated `Callback.on_epoch_end` hook in favour of `Callback.on_{train/val/test}_epoch_end` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
+
+- Deprecated `LightningModule.on_epoch_start` hook in favor of `LightningModule.on_{train/val/test}_epoch_start` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
+
+- Deprecated `LightningModule.on_epoch_end` hook in favor of `LightningModule.on_{train/val/test}_epoch_end` ([#11578](https://github.com/PyTorchLightning/pytorch-lightning/pull/11578))
+
 - Deprecated `on_before_accelerator_backend_setup` callback hook in favour of `setup` ([#11568](https://github.com/PyTorchLightning/pytorch-lightning/pull/11568))
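For downstream code, migrating is a rename to the stage-specific variant. A minimal sketch (the callback name here is hypothetical, not part of this commit):

```python
import pytorch_lightning as pl


class MyCallback(pl.Callback):
    # Before (deprecated in v1.6, removed in v1.8):
    #     def on_epoch_end(self, trainer, pl_module): ...
    # After: override only the stage(s) you actually care about.
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        print(f"train epoch {trainer.current_epoch} ended")

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        print("validation epoch ended")
```

Note that logic which previously ran for "either of train/val/test" epochs now needs an explicit override per stage.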
@@ -16,8 +16,8 @@ Built-in Actions

 PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:

-- on_epoch_start
-- on_epoch_end
+- on_train_epoch_start
+- on_train_epoch_end
 - on_train_batch_start
 - model_forward
 - model_backward
@@ -71,12 +71,10 @@ The profiler's results will be printed at the completion of a training ``trainer.fit()``

 |  on_train_batch_start                         |  0.00014637  |  0.0010246  |
 |  [LightningModule]BoringModel.teardown        |  2.15e-06    |  2.15e-06   |
 |  [LightningModule]BoringModel.prepare_data    |  1.955e-06   |  1.955e-06  |
-|  [LightningModule]BoringModel.on_epoch_end    |  1.8373e-06  |  5.512e-06  |
 |  [LightningModule]BoringModel.on_train_start  |  1.644e-06   |  1.644e-06  |
 |  [LightningModule]BoringModel.on_train_end    |  1.516e-06   |  1.516e-06  |
 |  [LightningModule]BoringModel.on_fit_end      |  1.426e-06   |  1.426e-06  |
 |  [LightningModule]BoringModel.setup           |  1.403e-06   |  1.403e-06  |
-|  [LightningModule]BoringModel.on_epoch_start  |  1.2883e-06  |  3.865e-06  |
 |  [LightningModule]BoringModel.on_fit_start    |  1.226e-06   |  1.226e-06  |
 -----------------------------------------------------------------------------------------------
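For context, the report above comes from the built-in simple profiler. A minimal sketch of enabling it (`model` is assumed to be an existing LightningModule instance):

```python
from pytorch_lightning import Trainer

# Enable the built-in SimpleProfiler via the string shorthand.
trainer = Trainer(profiler="simple", max_epochs=1)
# trainer.fit(model)  # the action/duration table is printed when fit() completes
```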
@@ -1223,7 +1223,6 @@ for more information.

     def fit_loop():
-        on_epoch_start()
         on_train_epoch_start()

         for batch in train_dataloader():
@@ -1254,7 +1253,6 @@ for more information.
         training_epoch_end()

         on_train_epoch_end()
-        on_epoch_end()

     def val_loop():
@@ -1262,7 +1260,6 @@ for more information.
         torch.set_grad_enabled(False)

         on_validation_start()
-        on_epoch_start()
         on_validation_epoch_start()

         val_outs = []
@@ -1281,7 +1278,6 @@ for more information.
         validation_epoch_end(val_outs)

         on_validation_epoch_end()
-        on_epoch_end()
         on_validation_end()

         # set up for train
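Taken together, the pseudocode hunks above mean the generic epoch hooks no longer appear anywhere in the loop; only the stage-specific ones fire. A sketch of a module tracking epoch boundaries under the new scheme (hypothetical name; `training_step`/`configure_optimizers` omitted for brevity):

```python
import pytorch_lightning as pl


class BoundaryModel(pl.LightningModule):
    # Sketch: only the epoch-boundary hooks are shown.
    def on_train_epoch_start(self) -> None:
        print("train epoch starting")

    def on_validation_epoch_end(self) -> None:
        print("validation epoch finished")
```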
@@ -1474,18 +1470,6 @@ on_train_batch_end

 .. automethod:: pytorch_lightning.core.lightning.LightningModule.on_train_batch_end
    :noindex:

-on_epoch_start
-~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_start
-   :noindex:
-
-on_epoch_end
-~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_epoch_end
-   :noindex:
-
 on_train_epoch_start
 ~~~~~~~~~~~~~~~~~~~~
@@ -324,15 +324,6 @@ on_predict_epoch_end

 .. automethod:: pytorch_lightning.callbacks.Callback.on_predict_epoch_end
    :noindex:

-on_epoch_start
-~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_start
-   :noindex:
-
-on_epoch_end
-~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.callbacks.Callback.on_epoch_end
-   :noindex:
@@ -85,7 +85,7 @@ class ImageSampler(pl.callbacks.Callback):
         )

     @rank_zero_only
-    def on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
+    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         if not _TORCHVISION_AVAILABLE:
             return
@@ -200,7 +200,7 @@ class GAN(LightningModule):
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
         return [opt_g, opt_d], []

-    def on_epoch_end(self):
+    def on_train_epoch_end(self):
         z = self.validation_z.type_as(self.generator.model[0].weight)

         # log sampled images
@@ -158,10 +158,22 @@ class Callback:
         """Called when the predict epoch ends."""

     def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """Called when either of train/val/test epoch begins."""
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_<train/validation/test>_epoch_start`` instead.
+
+        Called when either of train/val/test epoch begins.
+        """

     def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        """Called when either of train/val/test epoch ends."""
+        r"""
+        .. deprecated:: v1.6
+            This callback hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_<train/validation/test>_epoch_end`` instead.
+
+        Called when either of train/val/test epoch ends.
+        """

     def on_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         r"""
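A callback that still overrides the old hook keeps working through v1.7 but is flagged at `fit()` time by the configuration-validator check added further down in this commit. A hypothetical example:

```python
import pytorch_lightning as pl


class OldEpochCallback(pl.Callback):
    # Hypothetical example: this override now triggers the
    # "deprecated in v1.6 ... removed in v1.8" warning when used with a Trainer.
    def on_epoch_end(self, trainer, pl_module):
        print("some epoch (train/val/test) ended")
```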
@@ -182,10 +182,20 @@ class ModelHooks:
         self.trainer.model.eval()

     def on_epoch_start(self) -> None:
-        """Called when either of train/val/test epoch begins."""
+        r"""
+        .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_<train/validation/test>_epoch_start`` instead.
+
+        Called when either of train/val/test epoch begins.
+        """

     def on_epoch_end(self) -> None:
-        """Called when either of train/val/test epoch ends."""
+        r"""
+        .. deprecated:: v1.6 This hook was deprecated in v1.6 and will be removed in v1.8. Use
+            ``on_<train/validation/test>_epoch_end`` instead.
+
+        Called when either of train/val/test epoch ends.
+        """

     def on_train_epoch_start(self) -> None:
         """Called in the training loop at the very beginning of the epoch."""
@@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     _check_on_init_start_end(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)
+    # TODO: Delete on_epoch_start/on_epoch_end hooks in v1.8
+    _check_on_epoch_start_end(trainer, model)
     # TODO: Delete on_batch_start/on_batch_end hooks in v1.8
     _check_on_batch_start_end(trainer, model)
     # TODO: Remove this in v1.8
@@ -330,6 +332,29 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
         )


+# TODO: Remove on_epoch_start/on_epoch_end hooks in v1.8
+def _check_on_epoch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+    hooks = (
+        ["on_epoch_start", "on_<train/validation/test>_epoch_start"],
+        ["on_epoch_end", "on_<train/validation/test>_epoch_end"],
+    )
+
+    for hook, alternative_hook in hooks:
+        if is_overridden(hook, model):
+            rank_zero_deprecation(
+                f"The `LightningModule.{hook}` hook was deprecated in v1.6 and"
+                f" will be removed in v1.8. Please use `LightningModule.{alternative_hook}` instead."
+            )
+
+    for hook, alternative_hook in hooks:
+        for callback in trainer.callbacks:
+            if is_overridden(method_name=hook, instance=callback):
+                rank_zero_deprecation(
+                    f"The `Callback.{hook}` hook was deprecated in v1.6 and"
+                    f" will be removed in v1.8. Please use `Callback.{alternative_hook}` instead."
+                )
+
+
 # TODO: Remove on_batch_start/on_batch_end hooks in v1.8
 def _check_on_batch_start_end(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
     hooks = (["on_batch_start", "on_train_batch_start"], ["on_batch_end", "on_train_batch_end"])
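The check hinges on `is_overridden`, which reports whether an instance's class defines its own version of a method rather than inheriting the library default. A rough standalone sketch of that idea (not the library's actual implementation):

```python
def is_overridden_sketch(method_name: str, instance: object, parent: type) -> bool:
    # Rough idea only: the hook counts as user code when the instance's class
    # supplies its own function object instead of inheriting the parent's default.
    child_attr = getattr(type(instance), method_name, None)
    parent_attr = getattr(parent, method_name, None)
    return child_attr is not None and child_attr is not parent_attr


class Base:
    def on_epoch_end(self):
        ...


class Child(Base):
    def on_epoch_end(self):
        print("custom behaviour")


assert is_overridden_sketch("on_epoch_end", Child(), Base)
assert not is_overridden_sketch("on_epoch_end", Base(), Base)
```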
@@ -63,7 +63,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool):

 def test_free_memory_on_eval_outputs(tmpdir):
     class CB(Callback):
-        def on_epoch_end(self, trainer, pl_module):
+        def on_train_epoch_end(self, trainer, pl_module):
             assert not trainer._evaluation_loop._outputs

     model = BoringModel()
@@ -21,7 +21,6 @@ import torch

 import pytorch_lightning as pl
 from pytorch_lightning import Callback, Trainer
-from pytorch_lightning.callbacks import EarlyStopping
 from tests import _PATH_LEGACY, _PROJECT_ROOT

 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
@@ -55,7 +54,7 @@ class LimitNbEpochs(Callback):
         self.limit = nb
         self._count = 0

-    def on_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._count += 1
         if self._count >= self.limit:
             trainer.should_stop = True
@@ -73,7 +72,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):

     dm = ClassifDataModule()
     model = ClassificationModel()
-    es = EarlyStopping(monitor="val_acc", mode="max", min_delta=0.005)
     stop = LimitNbEpochs(1)

     trainer = Trainer(
@@ -81,7 +79,7 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
         accelerator="auto",
         devices=1,
         precision=(16 if torch.cuda.is_available() else 32),
-        callbacks=[es, stop],
+        callbacks=[stop],
         max_epochs=21,
         accumulate_grad_batches=2,
     )
@@ -432,7 +432,7 @@ def result_collection_reload(accelerator="auto", devices=1, **kwargs):

             return super().training_step(batch, batch_idx)

-        def on_epoch_end(self) -> None:
+        def on_train_epoch_end(self) -> None:
             if self.trainer.fit_loop.restarting:
                 total = sum(range(5)) * devices
                 metrics = self.results.metrics(on_step=False)
@@ -415,6 +415,37 @@ def test_v1_8_0_on_configure_sharded_model(tmpdir):
         trainer.fit(model)


+def test_v1_8_0_remove_on_epoch_start_end_lightning_module(tmpdir):
+    class CustomModel(BoringModel):
+        def on_epoch_start(self, *args, **kwargs):
+            print("on_epoch_start")
+
+    model = CustomModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+    )
+    with pytest.deprecated_call(
+        match="The `LightningModule.on_epoch_start` hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+
+    class CustomModel(BoringModel):
+        def on_epoch_end(self, *args, **kwargs):
+            print("on_epoch_end")
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        default_root_dir=tmpdir,
+    )
+
+    model = CustomModel()
+    with pytest.deprecated_call(
+        match="The `LightningModule.on_epoch_end` hook was deprecated in v1.6 and will be removed in v1.8"
+    ):
+        trainer.fit(model)
+
+
 def test_v1_8_0_rank_zero_imports():

     import warnings
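For users who cannot migrate immediately, the warning can be silenced explicitly until v1.8. A sketch; the message regex is an assumption based on the warning text above, and Lightning's deprecation warning class subclasses `DeprecationWarning`:

```python
import warnings

# Sketch: suppress only this specific deprecation until the code is migrated.
warnings.filterwarnings(
    "ignore",
    message=r".*on_epoch_(start|end)` hook was deprecated in v1.6.*",
    category=DeprecationWarning,
)
```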
@@ -755,7 +755,6 @@ def test_trainer_model_hook_system_predict(tmpdir):
         dict(name="zero_grad"),
         dict(name="Callback.on_predict_start", args=(trainer, model)),
         dict(name="on_predict_start"),
-        # TODO: `{,Callback}.on_epoch_{start,end}`
         dict(name="Callback.on_predict_epoch_start", args=(trainer, model)),
         dict(name="on_predict_epoch_start"),
         *model._predict_batch(trainer, model, batches),
@@ -50,9 +50,6 @@ class ModelTrainerPropertyParity(Callback):
     def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
         self._check_properties(trainer, pl_module)

-    def on_epoch_end(self, trainer, pl_module):
-        self._check_properties(trainer, pl_module)
-
     def on_train_end(self, trainer, pl_module):
         self._check_properties(trainer, pl_module)
@@ -347,7 +347,9 @@ def test_log_works_in_val_callback(tmpdir):
         max_epochs=1,
         callbacks=[cb],
     )
-    trainer.fit(model)
+    # TODO: Update this test in v1.8 (#11578)
+    with pytest.deprecated_call(match="`Callback.on_epoch_start` hook was deprecated in v1.6"):
+        trainer.fit(model)

     assert cb.call_counter == {
         "on_validation_batch_end": 4,
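This wrap pattern recurs in several tests below: `pytest.deprecated_call` fails the test unless the block raises a `DeprecationWarning` (or subclass) whose message matches the regex. A self-contained sketch of the mechanism:

```python
import warnings

import pytest


def emit():
    warnings.warn("`Callback.on_epoch_start` hook was deprecated in v1.6", DeprecationWarning)


def test_emit():
    # Passes only if a matching DeprecationWarning is raised inside the block.
    with pytest.deprecated_call(match="deprecated in v1.6"):
        emit()
```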
@@ -437,12 +439,6 @@ def test_log_works_in_test_callback(tmpdir):
         def on_test_start(self, _, pl_module):
             self.make_logging(pl_module, "on_test_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices)

-        def on_epoch_start(self, trainer, pl_module):
-            if trainer.testing:
-                self.make_logging(
-                    pl_module, "on_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices
-                )
-
         def on_test_epoch_start(self, _, pl_module):
             self.make_logging(
                 pl_module, "on_test_epoch_start", on_steps=[False], on_epochs=[True], prob_bars=self.choices
@@ -325,7 +325,10 @@ def test_log_works_in_train_callback(tmpdir):
         max_epochs=1,
         callbacks=[cb],
     )
-    trainer.fit(model)
+
+    # TODO: Update this test in v1.8 (#11578)
+    with pytest.deprecated_call(match="`Callback.on_epoch_start` hook was deprecated in v1.6"):
+        trainer.fit(model)

     # Make sure the func_name output equals the average from all logged values when on_epoch true
     assert trainer.progress_bar_callback.get_metrics(trainer, model)["train_loss"] == model.seen_losses[-1]
@@ -482,11 +485,11 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str):
             items.pop("v_num", None)
             return items

-        def on_epoch_end(self, trainer: Trainer, model: LightningModule):
+        def on_train_end(self, trainer: Trainer, model: LightningModule):
             metrics = self.get_metrics(trainer, model)
             assert metrics["foo"] == self.trainer.current_epoch
             assert metrics["foo_2"] == self.trainer.current_epoch
-            model.on_epoch_end_called = True
+            model.callback_on_train_end_called = True

     progress_bar = TestProgressBar()
     trainer = Trainer(
@@ -502,7 +505,7 @@ def test_progress_bar_metrics_contains_values_on_train_epoch_end(tmpdir: str):
     model = TestModel()
     trainer.fit(model)
     assert model.on_train_epoch_end_called
-    assert model.on_epoch_end_called
+    assert model.callback_on_train_end_called


 def test_logging_in_callbacks_with_log_function(tmpdir):
@@ -533,7 +536,10 @@ def test_logging_in_callbacks_with_log_function(tmpdir):
         enable_model_summary=False,
         callbacks=[LoggingCallback()],
     )
-    trainer.fit(model)
+
+    # TODO: Update this test in v1.8 (#11578)
+    with pytest.deprecated_call(match="`Callback.on_epoch_end` hook was deprecated in v1.6"):
+        trainer.fit(model)

     expected = {
         "on_train_start": 1,
@@ -412,7 +412,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
         num_batches_seen = 0
         num_on_load_checkpoint_called = 0

-        def on_epoch_end(self):
+        def on_train_epoch_end(self):
             self.num_epochs_end_seen += 1

         def on_train_batch_start(self, *_):
@@ -435,8 +435,7 @@ def test_fit_ckpt_path_epoch_restored(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     )
     trainer.fit(model)

-    # `on_epoch_end` will be called once for val_sanity, twice for train, twice for val
-    assert model.num_epochs_end_seen == 1 + 2 + 2
+    assert model.num_epochs_end_seen == 2
     assert model.num_batches_seen == trainer.num_training_batches * 2
     assert model.num_on_load_checkpoint_called == 0
@@ -1938,7 +1937,7 @@ def test_multiple_trainer_constant_memory_allocated(tmpdir):
         return torch.optim.Adam(self.layer.parameters(), lr=0.1)

     class Check(Callback):
-        def on_epoch_start(self, trainer, *_):
+        def on_train_epoch_start(self, trainer, *_):
             assert isinstance(trainer.strategy.model, DistributedDataParallel)

     def current_memory():