Add outputs param for `on_val/test_epoch_end` hooks (#6120)

* add outputs param for on_val/test_epoch_end hooks

* update changelog

* fix warning message

* add custom call hook

* cache logged metrics

* add args to docstrings

* use warning cache

* add utility method for param in sig check

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update docstring

* add test for eval epoch end hook

* add types and replace model ref

* add deprecation test

* fix test fx name

* add model hooks warning

* add old signature model to tests

* add clear warning cache

* sopport args param

* update tests

* add tests for model hooks

* code suggestions

* add signature utils

* fix pep8 issues

* fix pep8 issues

* fix outputs issue

* fix tests

* code fixes

* fix validate test

* test

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Kaushik B 2021-03-16 21:45:16 +05:30 committed by GitHub
parent 555a6fea21
commit b190403e28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 386 additions and 31 deletions

View File

@ -40,6 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
### Changed
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

View File

@ -17,7 +17,7 @@ Abstract base class used to build new callbacks.
"""
import abc
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from pytorch_lightning.core.lightning import LightningModule
@ -81,7 +81,7 @@ class Callback(abc.ABC):
"""Called when the train epoch begins."""
pass
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None:
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
"""Called when the train epoch ends."""
pass
@ -89,7 +89,7 @@ class Callback(abc.ABC):
"""Called when the val epoch begins."""
pass
def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None:
def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
"""Called when the val epoch ends."""
pass
@ -97,7 +97,7 @@ class Callback(abc.ABC):
"""Called when the test epoch begins."""
pass
def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
"""Called when the test epoch ends."""
pass

View File

@ -240,7 +240,7 @@ class ModelHooks:
"""
# do something when the epoch starts
def on_train_epoch_end(self, outputs) -> None:
def on_train_epoch_end(self, outputs: List[Any]) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
@ -252,7 +252,7 @@ class ModelHooks:
"""
# do something when the epoch starts
def on_validation_epoch_end(self) -> None:
def on_validation_epoch_end(self, outputs: List[Any]) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
@ -264,7 +264,7 @@ class ModelHooks:
"""
# do something when the epoch starts
def on_test_epoch_end(self) -> None:
def on_test_epoch_end(self, outputs: List[Any]) -> None:
"""
Called in the test loop at the very end of the epoch.
"""

View File

@ -20,6 +20,10 @@ from typing import Any, Callable, Dict, List, Optional, Type
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache
warning_cache = WarningCache()
class TrainerCallbackHookMixin(ABC):
@ -79,8 +83,12 @@ class TrainerCallbackHookMixin(ABC):
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.lightning_module)
def on_train_epoch_end(self, outputs):
"""Called when the epoch ends."""
def on_train_epoch_end(self, outputs: List[Any]):
"""Called when the epoch ends.
Args:
outputs: List of outputs on each ``train`` epoch
"""
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.lightning_module, outputs)
@ -89,20 +97,44 @@ class TrainerCallbackHookMixin(ABC):
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.lightning_module)
def on_validation_epoch_end(self):
"""Called when the epoch ends."""
def on_validation_epoch_end(self, outputs: List[Any]):
"""Called when the epoch ends.
Args:
outputs: List of outputs on each ``validation`` epoch
"""
for callback in self.callbacks:
callback.on_validation_epoch_end(self, self.lightning_module)
if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
callback.on_validation_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_validation_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_validation_epoch_end(self, self.lightning_module)
def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.lightning_module)
def on_test_epoch_end(self):
"""Called when the epoch ends."""
def on_test_epoch_end(self, outputs: List[Any]):
"""Called when the epoch ends.
Args:
outputs: List of outputs on each ``test`` epoch
"""
for callback in self.callbacks:
callback.on_test_epoch_end(self, self.lightning_module)
if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
callback.on_test_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_test_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_test_epoch_end(self, self.lightning_module)
def on_epoch_start(self):
"""Called when the epoch begins."""

View File

@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache
@ -202,9 +204,6 @@ class EvaluationLoop(object):
# with a single dataloader don't pass an array
outputs = self.outputs
# free memory
self.outputs = []
eval_results = outputs
if num_dataloaders == 1:
eval_results = outputs[0]
@ -313,13 +312,41 @@ class EvaluationLoop(object):
def on_evaluation_epoch_end(self, *args, **kwargs):
# call the callback hook
if self.trainer.testing:
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
else:
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
self.call_on_evaluation_epoch_end_hook()
self.trainer.call_hook('on_epoch_end')
def call_on_evaluation_epoch_end_hook(self):
outputs = self.outputs
# free memory
self.outputs = []
model_ref = self.trainer.lightning_module
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
self.trainer._reset_result_and_set_hook_fx_name(hook_name)
with self.trainer.profiler.profile(hook_name):
if hasattr(self.trainer, hook_name):
on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
on_evaluation_epoch_end_hook(outputs)
if is_overridden(hook_name, model_ref):
model_hook_fx = getattr(model_ref, hook_name)
if is_param_in_hook_signature(model_hook_fx, "outputs"):
model_hook_fx(outputs)
else:
self.warning_cache.warn(
f"`ModelHooks.{hook_name}` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
model_hook_fx()
self.trainer._cache_logged_metrics()
def log_evaluation_step_metrics(self, output, batch_idx):
if self.trainer.sanity_checking:
return

View File

@ -0,0 +1,22 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable
def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool:
hook_params = list(inspect.signature(hook_fx).parameters)
if "args" in hook_params or param in hook_params:
return True
return False

View File

@ -71,3 +71,66 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
results = trainer.fit(model)
assert results
def test_on_val_epoch_end_outputs(tmpdir):
class CB(Callback):
def on_validation_epoch_end(self, trainer, pl_module, outputs):
if trainer.running_sanity_check:
assert len(outputs[0]) == trainer.num_sanity_val_batches[0]
else:
assert len(outputs[0]) == trainer.num_val_batches[0]
model = BoringModel()
trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)
def test_on_test_epoch_end_outputs(tmpdir):
class CB(Callback):
def on_test_epoch_end(self, trainer, pl_module, outputs):
assert len(outputs[0]) == trainer.num_test_batches[0]
model = BoringModel()
trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
weights_summary=None,
)
trainer.test(model)
def test_free_memory_on_eval_outputs(tmpdir):
class CB(Callback):
def on_epoch_end(self, trainer, pl_module):
assert len(trainer.evaluation_loop.outputs) == 0
model = BoringModel()
trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)

View File

@ -56,7 +56,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_sanity_check_end(trainer, model),
@ -87,7 +87,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
@ -123,7 +123,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_test_batch_start(trainer, model, ANY, 1, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_test_epoch_end(trainer, model),
call.on_test_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_test_end(trainer, model),
call.teardown(trainer, model, 'test'),
@ -156,7 +156,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_validation_epoch_end(trainer, model),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.teardown(trainer, model, 'validate'),

56
tests/core/test_hooks.py Normal file
View File

@ -0,0 +1,56 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel
def test_on_val_epoch_end_outputs(tmpdir):
class TestModel(BoringModel):
def on_validation_epoch_end(self, outputs):
if trainer.running_sanity_check:
assert len(outputs[0]) == trainer.num_sanity_val_batches[0]
else:
assert len(outputs[0]) == trainer.num_val_batches[0]
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)
def test_on_test_epoch_end_outputs(tmpdir):
class TestModel(BoringModel):
def on_test_epoch_end(self, outputs):
assert len(outputs[0]) == trainer.num_test_batches[0]
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=2,
weights_summary=None,
)
trainer.test(model)

View File

@ -13,9 +13,27 @@
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z"""
import sys
from contextlib import contextmanager
from typing import Optional
import pytest
def _soft_unimport_module(str_module):
# once the module is imported e.g with parsing with pytest it lives in memory
if str_module in sys.modules:
del sys.modules[str_module]
@contextmanager
def no_deprecated_call(match: Optional[str] = None):
with pytest.warns(None) as record:
yield
try:
w = record.pop(DeprecationWarning)
if match is not None and match not in str(w.message):
return
except AssertionError:
# no DeprecationWarning raised
return
raise AssertionError(f"`DeprecationWarning` was raised: {w}")

View File

@ -20,6 +20,8 @@ from torch import optim
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
from tests.deprecated_api import no_deprecated_call
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call
@ -111,3 +113,93 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
ModelCheckpoint(dirpath=tmpdir)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
ModelCheckpoint(dirpath=tmpdir, period=1)
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
callback_warning_cache.clear()
class OldSignature(Callback):
def on_validation_epoch_end(self, trainer, pl_module): # noqa
...
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)
class OldSignatureModel(BoringModel):
def on_validation_epoch_end(self): # noqa
...
model = OldSignatureModel()
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)
callback_warning_cache.clear()
class NewSignature(Callback):
def on_validation_epoch_end(self, trainer, pl_module, outputs):
...
trainer.callbacks = [NewSignature()]
with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."):
trainer.fit(model)
class NewSignatureModel(BoringModel):
def on_validation_epoch_end(self, outputs):
...
model = NewSignatureModel()
with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."):
trainer.fit(model)
def test_v1_5_0_old_on_test_epoch_end(tmpdir):
callback_warning_cache.clear()
class OldSignature(Callback):
def on_test_epoch_end(self, trainer, pl_module): # noqa
...
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.test(model)
class OldSignatureModel(BoringModel):
def on_test_epoch_end(self): # noqa
...
model = OldSignatureModel()
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.test(model)
callback_warning_cache.clear()
class NewSignature(Callback):
def on_test_epoch_end(self, trainer, pl_module, outputs):
...
trainer.callbacks = [NewSignature()]
with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."):
trainer.test(model)
class NewSignatureModel(BoringModel):
def on_test_epoch_end(self, outputs):
...
model = NewSignatureModel()
with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
trainer.test(model)

View File

@ -360,9 +360,9 @@ def test_trainer_model_hook_system(tmpdir):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_epoch_start()
def on_validation_epoch_end(self):
def on_validation_epoch_end(self, outputs):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_validation_epoch_end()
super().on_validation_epoch_end(outputs)
def on_test_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
@ -380,9 +380,9 @@ def test_trainer_model_hook_system(tmpdir):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_epoch_start()
def on_test_epoch_end(self):
def on_test_epoch_end(self, outputs):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_epoch_end()
super().on_test_epoch_end(outputs)
def on_validation_model_eval(self):
self.called.append(inspect.currentframe().f_code.co_name)

View File

@ -126,7 +126,6 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
def validation_epoch_end(self, outputs):
self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
self.validation_epoch_end_called = True
assert len(self.trainer.evaluation_loop.outputs) == 0
def backward(self, loss, optimizer, optimizer_idx):
return LightningModule.backward(self, loss, optimizer, optimizer_idx)

View File

@ -0,0 +1,42 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel
@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.call_on_evaluation_epoch_end_hook")
def test_call_on_evaluation_epoch_end_hook(eval_epoch_end_mock, tmpdir):
"""
Tests that `call_on_evaluation_epoch_end_hook` is called
for `on_validation_epoch_end` and `on_test_epoch_end` hooks
"""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=2,
weights_summary=None,
)
trainer.fit(model)
# sanity + 2 epochs
assert eval_epoch_end_mock.call_count == 3
trainer.test()
# sanity + 2 epochs + called once for test
assert eval_epoch_end_mock.call_count == 4