diff --git a/CHANGELOG.md b/CHANGELOG.md
index c17f913cc9..f60d13f493 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -40,6 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
 
+- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
+
+
+
 ### Changed
 
 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 0ba1fd4ff7..db507fa991 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@ Abstract base class used to build new callbacks.
 
 """
 import abc
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 from pytorch_lightning.core.lightning import LightningModule
 
@@ -81,7 +81,7 @@ class Callback(abc.ABC):
         """Called when the train epoch begins."""
         pass
 
-    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None:
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
         """Called when the train epoch ends."""
         pass
 
@@ -89,7 +89,7 @@ class Callback(abc.ABC):
         """Called when the val epoch begins."""
         pass
 
-    def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None:
+    def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
         """Called when the val epoch ends."""
         pass
 
@@ -97,7 +97,7 @@ class Callback(abc.ABC):
         """Called when the test epoch begins."""
         pass
 
-    def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
+    def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
         """Called when the test epoch ends."""
         pass
 
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 1399d1b3c6..9624f94652 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -240,7 +240,7 @@ class ModelHooks:
         """
         # do something when the epoch starts
 
-    def on_train_epoch_end(self, outputs) -> None:
+    def on_train_epoch_end(self, outputs: List[Any]) -> None:
         """
         Called in the training loop at the very end of the epoch.
         """
@@ -252,7 +252,7 @@ class ModelHooks:
         """
         # do something when the epoch starts
 
-    def on_validation_epoch_end(self) -> None:
+    def on_validation_epoch_end(self, outputs: List[Any]) -> None:
         """
         Called in the validation loop at the very end of the epoch.
         """
@@ -264,7 +264,7 @@ class ModelHooks:
         """
         # do something when the epoch starts
 
-    def on_test_epoch_end(self) -> None:
+    def on_test_epoch_end(self, outputs: List[Any]) -> None:
         """
         Called in the test loop at the very end of the epoch.
         """
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index b44ba870d9..8823d48a78 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -20,6 +20,10 @@ from typing import Any, Callable, Dict, List, Optional, Type
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
+from pytorch_lightning.utilities.warnings import WarningCache
+
+warning_cache = WarningCache()
 
 
 class TrainerCallbackHookMixin(ABC):
@@ -79,8 +83,12 @@ class TrainerCallbackHookMixin(ABC):
         for callback in self.callbacks:
             callback.on_train_epoch_start(self, self.lightning_module)
 
-    def on_train_epoch_end(self, outputs):
-        """Called when the epoch ends."""
+    def on_train_epoch_end(self, outputs: List[Any]):
+        """Called when the epoch ends.
+
+        Args:
+            outputs: List of outputs on each ``train`` epoch
+        """
         for callback in self.callbacks:
             callback.on_train_epoch_end(self, self.lightning_module, outputs)
 
@@ -89,20 +97,44 @@ class TrainerCallbackHookMixin(ABC):
         for callback in self.callbacks:
             callback.on_validation_epoch_start(self, self.lightning_module)
 
-    def on_validation_epoch_end(self):
-        """Called when the epoch ends."""
+    def on_validation_epoch_end(self, outputs: List[Any]):
+        """Called when the epoch ends.
+
+        Args:
+            outputs: List of outputs on each ``validation`` epoch
+        """
         for callback in self.callbacks:
-            callback.on_validation_epoch_end(self, self.lightning_module)
+            if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
+                callback.on_validation_epoch_end(self, self.lightning_module, outputs)
+            else:
+                warning_cache.warn(
+                    "`Callback.on_validation_epoch_end` signature has changed in v1.3."
+                    " `outputs` parameter has been added."
+                    " Support for the old signature will be removed in v1.5", DeprecationWarning
+                )
+                callback.on_validation_epoch_end(self, self.lightning_module)
 
     def on_test_epoch_start(self):
         """Called when the epoch begins."""
         for callback in self.callbacks:
             callback.on_test_epoch_start(self, self.lightning_module)
 
-    def on_test_epoch_end(self):
-        """Called when the epoch ends."""
+    def on_test_epoch_end(self, outputs: List[Any]):
+        """Called when the epoch ends.
+
+        Args:
+            outputs: List of outputs on each ``test`` epoch
+        """
         for callback in self.callbacks:
-            callback.on_test_epoch_end(self, self.lightning_module)
+            if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
+                callback.on_test_epoch_end(self, self.lightning_module, outputs)
+            else:
+                warning_cache.warn(
+                    "`Callback.on_test_epoch_end` signature has changed in v1.3."
+                    " `outputs` parameter has been added."
+                    " Support for the old signature will be removed in v1.5", DeprecationWarning
+                )
+                callback.on_test_epoch_end(self, self.lightning_module)
 
     def on_epoch_start(self):
         """Called when the epoch begins."""
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 91cfc2ec75..20c842939f 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import torch
 
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.supporters import PredictionCollection
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.warnings import WarningCache
 
 
@@ -202,9 +204,6 @@ class EvaluationLoop(object):
         # with a single dataloader don't pass an array
         outputs = self.outputs
 
-        # free memory
-        self.outputs = []
-
         eval_results = outputs
         if num_dataloaders == 1:
             eval_results = outputs[0]
@@ -313,13 +312,41 @@ def on_evaluation_epoch_end(self, *args, **kwargs):
         # call the callback hook
-        if self.trainer.testing:
-            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
-        else:
-            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
+        self.call_on_evaluation_epoch_end_hook()
 
         self.trainer.call_hook('on_epoch_end')
 
+    def call_on_evaluation_epoch_end_hook(self):
+        outputs = self.outputs
+
+        # free memory
+        self.outputs = []
+
+        model_ref = self.trainer.lightning_module
+        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
+
+        self.trainer._reset_result_and_set_hook_fx_name(hook_name)
+
+        with self.trainer.profiler.profile(hook_name):
+
+            if hasattr(self.trainer, hook_name):
+                on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
+                on_evaluation_epoch_end_hook(outputs)
+
+            if is_overridden(hook_name, model_ref):
+                model_hook_fx = getattr(model_ref, hook_name)
+                if is_param_in_hook_signature(model_hook_fx, "outputs"):
+                    model_hook_fx(outputs)
+                else:
+                    self.warning_cache.warn(
+                        f"`ModelHooks.{hook_name}` signature has changed in v1.3."
+                        " `outputs` parameter has been added."
+                        " Support for the old signature will be removed in v1.5", DeprecationWarning
+                    )
+                    model_hook_fx()
+
+        self.trainer._cache_logged_metrics()
+
     def log_evaluation_step_metrics(self, output, batch_idx):
         if self.trainer.sanity_checking:
             return
diff --git a/pytorch_lightning/utilities/signature_utils.py b/pytorch_lightning/utilities/signature_utils.py
new file mode 100644
index 0000000000..546d8e845e
--- /dev/null
+++ b/pytorch_lightning/utilities/signature_utils.py
@@ -0,0 +1,22 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import inspect
+from typing import Callable
+
+
+def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool:
+    hook_params = list(inspect.signature(hook_fx).parameters)
+    if "args" in hook_params or param in hook_params:
+        return True
+    return False
diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py
index 78926cc9a7..df0eab31aa 100644
--- a/tests/callbacks/test_callback_hook_outputs.py
+++ b/tests/callbacks/test_callback_hook_outputs.py
@@ -71,3 +71,66 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
     results = trainer.fit(model)
 
     assert results
+
+
+def test_on_val_epoch_end_outputs(tmpdir):
+
+    class CB(Callback):
+
+        def on_validation_epoch_end(self, trainer, pl_module, outputs):
+            if trainer.running_sanity_check:
+                assert len(outputs[0]) == trainer.num_sanity_val_batches[0]
+            else:
+                assert len(outputs[0]) == trainer.num_val_batches[0]
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        callbacks=CB(),
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+
+
+def test_on_test_epoch_end_outputs(tmpdir):
+
+    class CB(Callback):
+
+        def on_test_epoch_end(self, trainer, pl_module, outputs):
+            assert len(outputs[0]) == trainer.num_test_batches[0]
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        callbacks=CB(),
+        default_root_dir=tmpdir,
+        weights_summary=None,
+    )
+
+    trainer.test(model)
+
+
+def test_free_memory_on_eval_outputs(tmpdir):
+
+    class CB(Callback):
+
+        def on_epoch_end(self, trainer, pl_module):
+            assert len(trainer.evaluation_loop.outputs) == 0
+
+    model = BoringModel()
+
+    trainer = Trainer(
+        callbacks=CB(),
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 626eb59dff..608f7bf105 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -56,7 +56,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
+        call.on_validation_epoch_end(trainer, model, ANY),
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
         call.on_sanity_check_end(trainer, model),
@@ -87,7 +87,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
         call.on_validation_epoch_start(trainer, model),
         call.on_validation_batch_start(trainer, model, ANY, 0, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
-        call.on_validation_epoch_end(trainer, model),
+        call.on_validation_epoch_end(trainer, model, ANY),
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
         call.on_save_checkpoint(trainer, model),  # should take ANY but we are inspecting signature for BC
@@ -123,7 +123,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
         call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
         call.on_test_batch_start(trainer, model, ANY, 1, 0),
         call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
-        call.on_test_epoch_end(trainer, model),
+        call.on_test_epoch_end(trainer, model, ANY),
         call.on_epoch_end(trainer, model),
         call.on_test_end(trainer, model),
         call.teardown(trainer, model, 'test'),
@@ -156,7 +156,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
         call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
         call.on_validation_batch_start(trainer, model, ANY, 1, 0),
         call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
-        call.on_validation_epoch_end(trainer, model),
+        call.on_validation_epoch_end(trainer, model, ANY),
         call.on_epoch_end(trainer, model),
         call.on_validation_end(trainer, model),
         call.teardown(trainer, model, 'validate'),
diff --git a/tests/core/test_hooks.py b/tests/core/test_hooks.py
new file mode 100644
index 0000000000..191da0a140
--- /dev/null
+++ b/tests/core/test_hooks.py
@@ -0,0 +1,56 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Trainer
+from tests.helpers.boring_model import BoringModel
+
+
+def test_on_val_epoch_end_outputs(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def on_validation_epoch_end(self, outputs):
+            if trainer.running_sanity_check:
+                assert len(outputs[0]) == trainer.num_sanity_val_batches[0]
+            else:
+                assert len(outputs[0]) == trainer.num_val_batches[0]
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+
+
+def test_on_test_epoch_end_outputs(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def on_test_epoch_end(self, outputs):
+            assert len(outputs[0]) == trainer.num_test_batches[0]
+
+    model = TestModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=2,
+        weights_summary=None,
+    )
+
+    trainer.test(model)
diff --git a/tests/deprecated_api/__init__.py b/tests/deprecated_api/__init__.py
index 99e21d1ed6..ccfae3ec8d 100644
--- a/tests/deprecated_api/__init__.py
+++ b/tests/deprecated_api/__init__.py
@@ -13,9 +13,27 @@
 # limitations under the License.
 """Test deprecated functionality which will be removed in vX.Y.Z"""
 import sys
+from contextlib import contextmanager
+from typing import Optional
+
+import pytest
 
 
 def _soft_unimport_module(str_module):
     # once the module is imported e.g with parsing with pytest it lives in memory
     if str_module in sys.modules:
         del sys.modules[str_module]
+
+
+@contextmanager
+def no_deprecated_call(match: Optional[str] = None):
+    with pytest.warns(None) as record:
+        yield
+    try:
+        w = record.pop(DeprecationWarning)
+        if match is not None and match not in str(w.message):
+            return
+    except AssertionError:
+        # no DeprecationWarning raised
+        return
+    raise AssertionError(f"`DeprecationWarning` was raised: {w}")
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index e65ebbab25..f449a37e33 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -20,6 +20,8 @@ from torch import optim
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
+from tests.deprecated_api import no_deprecated_call
 from tests.helpers import BoringModel
 from tests.helpers.utils import no_warning_call
 
@@ -111,3 +113,93 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
         ModelCheckpoint(dirpath=tmpdir)
     with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
         ModelCheckpoint(dirpath=tmpdir, period=1)
+
+
+def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
+    callback_warning_cache.clear()
+
+    class OldSignature(Callback):
+
+        def on_validation_epoch_end(self, trainer, pl_module):  # noqa
+            ...
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
+
+    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
+        trainer.fit(model)
+
+    class OldSignatureModel(BoringModel):
+
+        def on_validation_epoch_end(self):  # noqa
+            ...
+
+    model = OldSignatureModel()
+
+    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
+        trainer.fit(model)
+
+    callback_warning_cache.clear()
+
+    class NewSignature(Callback):
+
+        def on_validation_epoch_end(self, trainer, pl_module, outputs):
+            ...
+
+    trainer.callbacks = [NewSignature()]
+    with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."):
+        trainer.fit(model)
+
+    class NewSignatureModel(BoringModel):
+
+        def on_validation_epoch_end(self, outputs):
+            ...
+
+    model = NewSignatureModel()
+    with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."):
+        trainer.fit(model)
+
+
+def test_v1_5_0_old_on_test_epoch_end(tmpdir):
+    callback_warning_cache.clear()
+
+    class OldSignature(Callback):
+
+        def on_test_epoch_end(self, trainer, pl_module):  # noqa
+            ...
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
+
+    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
+        trainer.test(model)
+
+    class OldSignatureModel(BoringModel):
+
+        def on_test_epoch_end(self):  # noqa
+            ...
+
+    model = OldSignatureModel()
+
+    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
+        trainer.test(model)
+
+    callback_warning_cache.clear()
+
+    class NewSignature(Callback):
+
+        def on_test_epoch_end(self, trainer, pl_module, outputs):
+            ...
+
+    trainer.callbacks = [NewSignature()]
+    with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."):
+        trainer.test(model)
+
+    class NewSignatureModel(BoringModel):
+
+        def on_test_epoch_end(self, outputs):
+            ...
+
+    model = NewSignatureModel()
+    with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
+        trainer.test(model)
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 0d1c7cf40a..69859547f4 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -360,9 +360,9 @@ def test_trainer_model_hook_system(tmpdir):
             self.called.append(inspect.currentframe().f_code.co_name)
             super().on_validation_epoch_start()
 
-        def on_validation_epoch_end(self):
+        def on_validation_epoch_end(self, outputs):
             self.called.append(inspect.currentframe().f_code.co_name)
-            super().on_validation_epoch_end()
+            super().on_validation_epoch_end(outputs)
 
         def on_test_start(self):
             self.called.append(inspect.currentframe().f_code.co_name)
@@ -380,9 +380,9 @@ def test_trainer_model_hook_system(tmpdir):
             self.called.append(inspect.currentframe().f_code.co_name)
             super().on_test_epoch_start()
 
-        def on_test_epoch_end(self):
+        def on_test_epoch_end(self, outputs):
             self.called.append(inspect.currentframe().f_code.co_name)
-            super().on_test_epoch_end()
+            super().on_test_epoch_end(outputs)
 
         def on_validation_model_eval(self):
             self.called.append(inspect.currentframe().f_code.co_name)
diff --git a/tests/trainer/logging_/test_eval_loop_logging_1_0.py b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
index 72084454ba..e5cf596a78 100644
--- a/tests/trainer/logging_/test_eval_loop_logging_1_0.py
+++ b/tests/trainer/logging_/test_eval_loop_logging_1_0.py
@@ -126,7 +126,6 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
         def validation_epoch_end(self, outputs):
             self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
             self.validation_epoch_end_called = True
-            assert len(self.trainer.evaluation_loop.outputs) == 0
 
         def backward(self, loss, optimizer, optimizer_idx):
             return LightningModule.backward(self, loss, optimizer, optimizer_idx)
diff --git a/tests/trainer/test_evaluation_loop.py b/tests/trainer/test_evaluation_loop.py
new file mode 100644
index 0000000000..3fe58afde7
--- /dev/null
+++ b/tests/trainer/test_evaluation_loop.py
@@ -0,0 +1,42 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+from pytorch_lightning import Trainer
+from tests.helpers.boring_model import BoringModel
+
+
+@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.call_on_evaluation_epoch_end_hook")
+def test_call_on_evaluation_epoch_end_hook(eval_epoch_end_mock, tmpdir):
+    """
+    Tests that `call_on_evaluation_epoch_end_hook` is called
+    for `on_validation_epoch_end` and `on_test_epoch_end` hooks
+    """
+    model = BoringModel()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=2,
+        weights_summary=None,
+    )
+
+    trainer.fit(model)
+    # sanity + 2 epochs
+    assert eval_epoch_end_mock.call_count == 3
+
+    trainer.test()
+    # sanity + 2 epochs + called once for test
+    assert eval_epoch_end_mock.call_count == 4