diff --git a/CHANGELOG.md b/CHANGELOG.md index 24706f9fed..a93337f8c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,9 +88,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618)) -- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) - - - Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679)) @@ -213,6 +210,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323)) +- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339)) + + - Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292)) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 3e8a77cbfd..8283c2ddd7 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -22,7 +22,7 @@ from typing import Any, Dict, List, Optional from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT class Callback(abc.ABC): @@ -108,9 +108,7 @@ class Callback(abc.ABC): """Called when the val epoch begins.""" pass - def on_validation_epoch_end( - self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT - ) -> None: + def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the val epoch ends.""" pass @@ -118,7 +116,7 @@ class Callback(abc.ABC): """Called when the test epoch begins.""" pass - def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None: + def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None: """Called when the test epoch ends.""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index bebd1edd8e..d311bd4f58 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -20,7 +20,7 @@ from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT class ModelHooks: @@ -245,7 +245,7 @@ class ModelHooks: Called in the validation loop at the very beginning of the epoch. """ - def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def on_validation_epoch_end(self) -> None: """ Called in the validation loop at the very end of the epoch. """ @@ -255,7 +255,7 @@ class ModelHooks: Called in the test loop at the very beginning of the epoch. """ - def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def on_test_epoch_end(self) -> None: """ Called in the test loop at the very end of the epoch. """ diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index fcdd8f55f6..23df26b410 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -111,44 +111,20 @@ class TrainerCallbackHookMixin(ABC): for callback in self.callbacks: callback.on_validation_epoch_start(self, self.lightning_module) - def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``validation`` epoch - """ + def on_validation_epoch_end(self): + """Called when the validation epoch ends.""" for callback in self.callbacks: - if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"): - callback.on_validation_epoch_end(self, self.lightning_module, outputs) - else: - warning_cache.warn( - "`Callback.on_validation_epoch_end` signature has changed in v1.3." - " `outputs` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - callback.on_validation_epoch_end(self, self.lightning_module) + callback.on_validation_epoch_end(self, self.lightning_module) def on_test_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: callback.on_test_epoch_start(self, self.lightning_module) - def on_test_epoch_end(self, outputs: EPOCH_OUTPUT): - """Called when the epoch ends. - - Args: - outputs: List of outputs on each ``test`` epoch - """ + def on_test_epoch_end(self): + """Called when the test epoch ends.""" for callback in self.callbacks: - if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"): - callback.on_test_epoch_end(self, self.lightning_module, outputs) - else: - warning_cache.warn( - "`Callback.on_test_epoch_end` signature has changed in v1.3." - " `outputs` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - callback.on_test_epoch_end(self, self.lightning_module) + callback.on_test_epoch_end(self, self.lightning_module) def on_predict_epoch_start(self) -> None: """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 8201d700d3..add4a0cbc8 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from torch.utils.data import DataLoader @@ -20,7 +20,6 @@ from pytorch_lightning.core.step_result import Result from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT from pytorch_lightning.utilities.warnings import WarningCache @@ -76,6 +75,7 @@ class EvaluationLoop(object): return sum(max_batches) == 0 def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None: + self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end() if self.trainer.testing: self.trainer.call_hook('on_test_start', *args, **kwargs) else: @@ -188,6 +188,13 @@ class EvaluationLoop(object): output = self.trainer.call_hook('validation_step_end', *args, **kwargs) return output + def _should_track_batch_outputs_for_epoch_end(self) -> bool: + model = self.trainer.lightning_module + if self.trainer.testing: + return is_overridden('test_epoch_end', model=model) + else: + return is_overridden('validation_epoch_end', model=model) + def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: # unset dataloder_idx in model self.trainer.logger_connector.evaluation_epoch_end() @@ -241,7 +248,7 @@ class EvaluationLoop(object): # track debug metrics self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output) - def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None: + def on_evaluation_epoch_end(self) -> None: model_ref = self.trainer.lightning_module hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end" @@ -251,18 +258,11 @@ class EvaluationLoop(object): if hasattr(self.trainer, hook_name): on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name) - on_evaluation_epoch_end_hook(outputs) + on_evaluation_epoch_end_hook() if is_overridden(hook_name, model_ref): model_hook_fx = getattr(model_ref, hook_name) - if is_param_in_hook_signature(model_hook_fx, "outputs"): - model_hook_fx(outputs) - else: - self.warning_cache.warn( - f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added." - " Support for the old signature will be removed in v1.5", DeprecationWarning - ) - model_hook_fx() + model_hook_fx() self.trainer._cache_logged_metrics() diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e014671d9c..2a6a53a7c1 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -972,7 +972,8 @@ class Trainer( dl_outputs = self.track_output_for_epoch_end(dl_outputs, output) # store batch level output per dataloader - self.evaluation_loop.outputs.append(dl_outputs) + if self.evaluation_loop.should_track_batch_outputs_for_epoch_end: + self.evaluation_loop.outputs.append(dl_outputs) outputs = self.evaluation_loop.outputs @@ -980,14 +981,14 @@ class Trainer( self.evaluation_loop.outputs = [] # with a single dataloader don't pass a 2D list - if self.evaluation_loop.num_dataloaders == 1: + if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1: outputs = outputs[0] # lightning module method self.evaluation_loop.evaluation_epoch_end(outputs) # hook - self.evaluation_loop.on_evaluation_epoch_end(outputs) + self.evaluation_loop.on_evaluation_epoch_end() # update epoch-level lr_schedulers if on_epoch: @@ -1212,8 +1213,8 @@ class Trainer( def call_hook(self, hook_name: str, *args, **kwargs) -> Any: # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook - # This was done to manage the deprecation of an argument to on_train_epoch_end - # If making chnages to this function, ensure that those changes are also made to + # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end + # If making changes to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook # set hook_name to model + reset Result obj diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index b2aa20af57..36322482c5 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -65,48 +65,6 @@ def test_train_step_no_return(tmpdir, single_cb: bool): trainer.fit(model) -def test_on_val_epoch_end_outputs(tmpdir): - - class CB(Callback): - - def on_validation_epoch_end(self, trainer, pl_module, outputs): - if trainer.running_sanity_check: - assert len(outputs) == trainer.num_sanity_val_batches[0] - else: - assert len(outputs) == trainer.num_val_batches[0] - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - weights_summary=None, - ) - - trainer.fit(model) - - -def test_on_test_epoch_end_outputs(tmpdir): - - class CB(Callback): - - def on_test_epoch_end(self, trainer, pl_module, outputs): - assert len(outputs) == trainer.num_test_batches[0] - - model = BoringModel() - - trainer = Trainer( - callbacks=CB(), - default_root_dir=tmpdir, - weights_summary=None, - ) - - trainer.test(model) - - def test_free_memory_on_eval_outputs(tmpdir): class CB(Callback): diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index a30b4fe0f6..9b048e022c 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -58,7 +58,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_sanity_check_end(trainer, model), @@ -90,7 +90,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC @@ -128,7 +128,7 @@ def test_trainer_callback_hook_system_test(tmpdir): call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_test_batch_start(trainer, model, ANY, 1, 0), call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_test_epoch_end(trainer, model, ANY), + call.on_test_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_test_end(trainer, model), call.teardown(trainer, model, 'test'), @@ -163,7 +163,7 @@ def test_trainer_callback_hook_system_validate(tmpdir): call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), call.on_validation_batch_start(trainer, model, ANY, 1, 0), call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0), - call.on_validation_epoch_end(trainer, model, ANY), + call.on_validation_epoch_end(trainer, model), call.on_epoch_end(trainer, model), call.on_validation_end(trainer, model), call.teardown(trainer, model, 'validate'), diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py deleted file mode 100644 index 087f884d96..0000000000 --- a/tests/core/test_hooks.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from pytorch_lightning import Trainer -from tests.helpers.boring_model import BoringModel - - -def test_on_val_epoch_end_outputs(tmpdir): - - class TestModel(BoringModel): - - def on_validation_epoch_end(self, outputs): - if trainer.running_sanity_check: - assert len(outputs) == trainer.num_sanity_val_batches[0] - else: - assert len(outputs) == trainer.num_val_batches[0] - - model = TestModel() - - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - weights_summary=None, - ) - - trainer.fit(model) - - -def test_on_test_epoch_end_outputs(tmpdir): - - class TestModel(BoringModel): - - def on_test_epoch_end(self, outputs): - assert len(outputs) == trainer.num_test_batches[0] - - model = TestModel() - - trainer = Trainer( - default_root_dir=tmpdir, - fast_dev_run=2, - weights_summary=None, - ) - - trainer.test(model) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index d49e191e69..91b93a88f0 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -263,96 +263,6 @@ def test_v1_5_0_old_on_train_epoch_end(tmpdir): trainer.fit(model) -def test_v1_5_0_old_on_validation_epoch_end(tmpdir): - callback_warning_cache.clear() - - class OldSignature(Callback): - - def on_validation_epoch_end(self, trainer, pl_module): # noqa - ... - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - class OldSignatureModel(BoringModel): - - def on_validation_epoch_end(self): # noqa - ... - - model = OldSignatureModel() - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.fit(model) - - callback_warning_cache.clear() - - class NewSignature(Callback): - - def on_validation_epoch_end(self, trainer, pl_module, outputs): - ... - - trainer.callbacks = [NewSignature()] - with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."): - trainer.fit(model) - - class NewSignatureModel(BoringModel): - - def on_validation_epoch_end(self, outputs): - ... - - model = NewSignatureModel() - with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."): - trainer.fit(model) - - -def test_v1_5_0_old_on_test_epoch_end(tmpdir): - callback_warning_cache.clear() - - class OldSignature(Callback): - - def on_test_epoch_end(self, trainer, pl_module): # noqa - ... - - model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.test(model) - - class OldSignatureModel(BoringModel): - - def on_test_epoch_end(self): # noqa - ... - - model = OldSignatureModel() - - with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer.test(model) - - callback_warning_cache.clear() - - class NewSignature(Callback): - - def on_test_epoch_end(self, trainer, pl_module, outputs): - ... - - trainer.callbacks = [NewSignature()] - with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."): - trainer.test(model) - - class NewSignatureModel(BoringModel): - - def on_test_epoch_end(self, outputs): - ... - - model = NewSignatureModel() - with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."): - trainer.test(model) - - @pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler)) def test_v1_5_0_profiler_output_filename(tmpdir, cls): filepath = str(tmpdir / "test.txt") diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 06eaca6d61..ab1ce3367c 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -681,10 +681,10 @@ def test_metrics_reset(tmpdir): def on_train_epoch_end(self): self._assert_epoch_end('train') - def on_validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): self._assert_epoch_end('val') - def on_test_epoch_end(self, outputs): + def on_test_epoch_end(self): self._assert_epoch_end('test') def _assert_called(model, stage):