Add outputs param for `on_val/test_epoch_end` hooks (#6120)
* add outputs param for on_val/test_epoch_end hooks * update changelog * fix warning message * add custom call hook * cache logged metrics * add args to docstrings * use warning cache * add utility method for param in sig check * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * update docstring * add test for eval epoch end hook * add types and replace model ref * add deprecation test * fix test fx name * add model hooks warning * add old signature model to tests * add clear warning cache * sopport args param * update tests * add tests for model hooks * code suggestions * add signature utils * fix pep8 issues * fix pep8 issues * fix outputs issue * fix tests * code fixes * fix validate test * test Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
555a6fea21
commit
b190403e28
|
@ -40,6 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
|
||||
|
||||
|
||||
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
|
||||
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
|
||||
|
|
|
@ -17,7 +17,7 @@ Abstract base class used to build new callbacks.
|
|||
"""
|
||||
|
||||
import abc
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
||||
|
@ -81,7 +81,7 @@ class Callback(abc.ABC):
|
|||
"""Called when the train epoch begins."""
|
||||
pass
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: Any) -> None:
|
||||
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
|
||||
"""Called when the train epoch ends."""
|
||||
pass
|
||||
|
||||
|
@ -89,7 +89,7 @@ class Callback(abc.ABC):
|
|||
"""Called when the val epoch begins."""
|
||||
pass
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module: LightningModule) -> None:
|
||||
def on_validation_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
|
||||
"""Called when the val epoch ends."""
|
||||
pass
|
||||
|
||||
|
@ -97,7 +97,7 @@ class Callback(abc.ABC):
|
|||
"""Called when the test epoch begins."""
|
||||
pass
|
||||
|
||||
def on_test_epoch_end(self, trainer, pl_module: LightningModule) -> None:
|
||||
def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[Any]) -> None:
|
||||
"""Called when the test epoch ends."""
|
||||
pass
|
||||
|
||||
|
|
|
@ -240,7 +240,7 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the epoch starts
|
||||
|
||||
def on_train_epoch_end(self, outputs) -> None:
|
||||
def on_train_epoch_end(self, outputs: List[Any]) -> None:
|
||||
"""
|
||||
Called in the training loop at the very end of the epoch.
|
||||
"""
|
||||
|
@ -252,7 +252,7 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the epoch starts
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
def on_validation_epoch_end(self, outputs: List[Any]) -> None:
|
||||
"""
|
||||
Called in the validation loop at the very end of the epoch.
|
||||
"""
|
||||
|
@ -264,7 +264,7 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the epoch starts
|
||||
|
||||
def on_test_epoch_end(self) -> None:
|
||||
def on_test_epoch_end(self, outputs: List[Any]) -> None:
|
||||
"""
|
||||
Called in the test loop at the very end of the epoch.
|
||||
"""
|
||||
|
|
|
@ -20,6 +20,10 @@ from typing import Any, Callable, Dict, List, Optional, Type
|
|||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
||||
warning_cache = WarningCache()
|
||||
|
||||
|
||||
class TrainerCallbackHookMixin(ABC):
|
||||
|
@ -79,8 +83,12 @@ class TrainerCallbackHookMixin(ABC):
|
|||
for callback in self.callbacks:
|
||||
callback.on_train_epoch_start(self, self.lightning_module)
|
||||
|
||||
def on_train_epoch_end(self, outputs):
|
||||
"""Called when the epoch ends."""
|
||||
def on_train_epoch_end(self, outputs: List[Any]):
|
||||
"""Called when the epoch ends.
|
||||
|
||||
Args:
|
||||
outputs: List of outputs on each ``train`` epoch
|
||||
"""
|
||||
for callback in self.callbacks:
|
||||
callback.on_train_epoch_end(self, self.lightning_module, outputs)
|
||||
|
||||
|
@ -89,20 +97,44 @@ class TrainerCallbackHookMixin(ABC):
|
|||
for callback in self.callbacks:
|
||||
callback.on_validation_epoch_start(self, self.lightning_module)
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
"""Called when the epoch ends."""
|
||||
def on_validation_epoch_end(self, outputs: List[Any]):
|
||||
"""Called when the epoch ends.
|
||||
|
||||
Args:
|
||||
outputs: List of outputs on each ``validation`` epoch
|
||||
"""
|
||||
for callback in self.callbacks:
|
||||
callback.on_validation_epoch_end(self, self.lightning_module)
|
||||
if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
|
||||
callback.on_validation_epoch_end(self, self.lightning_module, outputs)
|
||||
else:
|
||||
warning_cache.warn(
|
||||
"`Callback.on_validation_epoch_end` signature has changed in v1.3."
|
||||
" `outputs` parameter has been added."
|
||||
" Support for the old signature will be removed in v1.5", DeprecationWarning
|
||||
)
|
||||
callback.on_validation_epoch_end(self, self.lightning_module)
|
||||
|
||||
def on_test_epoch_start(self):
|
||||
"""Called when the epoch begins."""
|
||||
for callback in self.callbacks:
|
||||
callback.on_test_epoch_start(self, self.lightning_module)
|
||||
|
||||
def on_test_epoch_end(self):
|
||||
"""Called when the epoch ends."""
|
||||
def on_test_epoch_end(self, outputs: List[Any]):
|
||||
"""Called when the epoch ends.
|
||||
|
||||
Args:
|
||||
outputs: List of outputs on each ``test`` epoch
|
||||
"""
|
||||
for callback in self.callbacks:
|
||||
callback.on_test_epoch_end(self, self.lightning_module)
|
||||
if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
|
||||
callback.on_test_epoch_end(self, self.lightning_module, outputs)
|
||||
else:
|
||||
warning_cache.warn(
|
||||
"`Callback.on_test_epoch_end` signature has changed in v1.3."
|
||||
" `outputs` parameter has been added."
|
||||
" Support for the old signature will be removed in v1.5", DeprecationWarning
|
||||
)
|
||||
callback.on_test_epoch_end(self, self.lightning_module)
|
||||
|
||||
def on_epoch_start(self):
|
||||
"""Called when the epoch begins."""
|
||||
|
|
|
@ -11,12 +11,14 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.trainer.supporters import PredictionCollection
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
||||
|
||||
|
@ -202,9 +204,6 @@ class EvaluationLoop(object):
|
|||
# with a single dataloader don't pass an array
|
||||
outputs = self.outputs
|
||||
|
||||
# free memory
|
||||
self.outputs = []
|
||||
|
||||
eval_results = outputs
|
||||
if num_dataloaders == 1:
|
||||
eval_results = outputs[0]
|
||||
|
@ -313,13 +312,41 @@ class EvaluationLoop(object):
|
|||
|
||||
def on_evaluation_epoch_end(self, *args, **kwargs):
|
||||
# call the callback hook
|
||||
if self.trainer.testing:
|
||||
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
|
||||
self.call_on_evaluation_epoch_end_hook()
|
||||
|
||||
self.trainer.call_hook('on_epoch_end')
|
||||
|
||||
def call_on_evaluation_epoch_end_hook(self):
|
||||
outputs = self.outputs
|
||||
|
||||
# free memory
|
||||
self.outputs = []
|
||||
|
||||
model_ref = self.trainer.lightning_module
|
||||
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
|
||||
|
||||
self.trainer._reset_result_and_set_hook_fx_name(hook_name)
|
||||
|
||||
with self.trainer.profiler.profile(hook_name):
|
||||
|
||||
if hasattr(self.trainer, hook_name):
|
||||
on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
|
||||
on_evaluation_epoch_end_hook(outputs)
|
||||
|
||||
if is_overridden(hook_name, model_ref):
|
||||
model_hook_fx = getattr(model_ref, hook_name)
|
||||
if is_param_in_hook_signature(model_hook_fx, "outputs"):
|
||||
model_hook_fx(outputs)
|
||||
else:
|
||||
self.warning_cache.warn(
|
||||
f"`ModelHooks.{hook_name}` signature has changed in v1.3."
|
||||
" `outputs` parameter has been added."
|
||||
" Support for the old signature will be removed in v1.5", DeprecationWarning
|
||||
)
|
||||
model_hook_fx()
|
||||
|
||||
self.trainer._cache_logged_metrics()
|
||||
|
||||
def log_evaluation_step_metrics(self, output, batch_idx):
|
||||
if self.trainer.sanity_checking:
|
||||
return
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import Callable
|
||||
|
||||
|
||||
def is_param_in_hook_signature(hook_fx: Callable, param: str) -> bool:
|
||||
hook_params = list(inspect.signature(hook_fx).parameters)
|
||||
if "args" in hook_params or param in hook_params:
|
||||
return True
|
||||
return False
|
|
@ -71,3 +71,66 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
|
|||
|
||||
results = trainer.fit(model)
|
||||
assert results
|
||||
|
||||
|
||||
def test_on_val_epoch_end_outputs(tmpdir):
|
||||
|
||||
class CB(Callback):
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module, outputs):
|
||||
if trainer.running_sanity_check:
|
||||
assert len(outputs[0]) == trainer.num_sanity_val_batches[0]
|
||||
else:
|
||||
assert len(outputs[0]) == trainer.num_val_batches[0]
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(
|
||||
callbacks=CB(),
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_on_test_epoch_end_outputs(tmpdir):
|
||||
|
||||
class CB(Callback):
|
||||
|
||||
def on_test_epoch_end(self, trainer, pl_module, outputs):
|
||||
assert len(outputs[0]) == trainer.num_test_batches[0]
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(
|
||||
callbacks=CB(),
|
||||
default_root_dir=tmpdir,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.test(model)
|
||||
|
||||
|
||||
def test_free_memory_on_eval_outputs(tmpdir):
|
||||
|
||||
class CB(Callback):
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
assert len(trainer.evaluation_loop.outputs) == 0
|
||||
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(
|
||||
callbacks=CB(),
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
|
|
|
@ -56,7 +56,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
|||
call.on_validation_epoch_start(trainer, model),
|
||||
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
|
||||
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
|
||||
call.on_validation_epoch_end(trainer, model),
|
||||
call.on_validation_epoch_end(trainer, model, ANY),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_validation_end(trainer, model),
|
||||
call.on_sanity_check_end(trainer, model),
|
||||
|
@ -87,7 +87,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
|
|||
call.on_validation_epoch_start(trainer, model),
|
||||
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
|
||||
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
|
||||
call.on_validation_epoch_end(trainer, model),
|
||||
call.on_validation_epoch_end(trainer, model, ANY),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_validation_end(trainer, model),
|
||||
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
|
||||
|
@ -123,7 +123,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
|
|||
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
|
||||
call.on_test_batch_start(trainer, model, ANY, 1, 0),
|
||||
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
|
||||
call.on_test_epoch_end(trainer, model),
|
||||
call.on_test_epoch_end(trainer, model, ANY),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_test_end(trainer, model),
|
||||
call.teardown(trainer, model, 'test'),
|
||||
|
@ -156,7 +156,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
|
|||
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
|
||||
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
|
||||
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
|
||||
call.on_validation_epoch_end(trainer, model),
|
||||
call.on_validation_epoch_end(trainer, model, ANY),
|
||||
call.on_epoch_end(trainer, model),
|
||||
call.on_validation_end(trainer, model),
|
||||
call.teardown(trainer, model, 'validate'),
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
|
||||
|
||||
def test_on_val_epoch_end_outputs(tmpdir):
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def on_validation_epoch_end(self, outputs):
|
||||
if trainer.running_sanity_check:
|
||||
assert len(outputs[0]) == trainer.num_sanity_val_batches[0]
|
||||
else:
|
||||
assert len(outputs[0]) == trainer.num_val_batches[0]
|
||||
|
||||
model = TestModel()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_on_test_epoch_end_outputs(tmpdir):
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def on_test_epoch_end(self, outputs):
|
||||
assert len(outputs[0]) == trainer.num_test_batches[0]
|
||||
|
||||
model = TestModel()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
fast_dev_run=2,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.test(model)
|
|
@ -13,9 +13,27 @@
|
|||
# limitations under the License.
|
||||
"""Test deprecated functionality which will be removed in vX.Y.Z"""
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _soft_unimport_module(str_module):
|
||||
# once the module is imported e.g with parsing with pytest it lives in memory
|
||||
if str_module in sys.modules:
|
||||
del sys.modules[str_module]
|
||||
|
||||
|
||||
@contextmanager
|
||||
def no_deprecated_call(match: Optional[str] = None):
|
||||
with pytest.warns(None) as record:
|
||||
yield
|
||||
try:
|
||||
w = record.pop(DeprecationWarning)
|
||||
if match is not None and match not in str(w.message):
|
||||
return
|
||||
except AssertionError:
|
||||
# no DeprecationWarning raised
|
||||
return
|
||||
raise AssertionError(f"`DeprecationWarning` was raised: {w}")
|
||||
|
|
|
@ -20,6 +20,8 @@ from torch import optim
|
|||
from pytorch_lightning import Callback, Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from pytorch_lightning.trainer.callback_hook import warning_cache as callback_warning_cache
|
||||
from tests.deprecated_api import no_deprecated_call
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.utils import no_warning_call
|
||||
|
||||
|
@ -111,3 +113,93 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
|
|||
ModelCheckpoint(dirpath=tmpdir)
|
||||
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
|
||||
ModelCheckpoint(dirpath=tmpdir, period=1)
|
||||
|
||||
|
||||
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
|
||||
callback_warning_cache.clear()
|
||||
|
||||
class OldSignature(Callback):
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module): # noqa
|
||||
...
|
||||
|
||||
model = BoringModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
|
||||
|
||||
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
|
||||
trainer.fit(model)
|
||||
|
||||
class OldSignatureModel(BoringModel):
|
||||
|
||||
def on_validation_epoch_end(self): # noqa
|
||||
...
|
||||
|
||||
model = OldSignatureModel()
|
||||
|
||||
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
|
||||
trainer.fit(model)
|
||||
|
||||
callback_warning_cache.clear()
|
||||
|
||||
class NewSignature(Callback):
|
||||
|
||||
def on_validation_epoch_end(self, trainer, pl_module, outputs):
|
||||
...
|
||||
|
||||
trainer.callbacks = [NewSignature()]
|
||||
with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."):
|
||||
trainer.fit(model)
|
||||
|
||||
class NewSignatureModel(BoringModel):
|
||||
|
||||
def on_validation_epoch_end(self, outputs):
|
||||
...
|
||||
|
||||
model = NewSignatureModel()
|
||||
with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."):
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
def test_v1_5_0_old_on_test_epoch_end(tmpdir):
|
||||
callback_warning_cache.clear()
|
||||
|
||||
class OldSignature(Callback):
|
||||
|
||||
def on_test_epoch_end(self, trainer, pl_module): # noqa
|
||||
...
|
||||
|
||||
model = BoringModel()
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
|
||||
|
||||
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
|
||||
trainer.test(model)
|
||||
|
||||
class OldSignatureModel(BoringModel):
|
||||
|
||||
def on_test_epoch_end(self): # noqa
|
||||
...
|
||||
|
||||
model = OldSignatureModel()
|
||||
|
||||
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
|
||||
trainer.test(model)
|
||||
|
||||
callback_warning_cache.clear()
|
||||
|
||||
class NewSignature(Callback):
|
||||
|
||||
def on_test_epoch_end(self, trainer, pl_module, outputs):
|
||||
...
|
||||
|
||||
trainer.callbacks = [NewSignature()]
|
||||
with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."):
|
||||
trainer.test(model)
|
||||
|
||||
class NewSignatureModel(BoringModel):
|
||||
|
||||
def on_test_epoch_end(self, outputs):
|
||||
...
|
||||
|
||||
model = NewSignatureModel()
|
||||
with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
|
||||
trainer.test(model)
|
||||
|
|
|
@ -360,9 +360,9 @@ def test_trainer_model_hook_system(tmpdir):
|
|||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_epoch_start()
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
def on_validation_epoch_end(self, outputs):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_epoch_end()
|
||||
super().on_validation_epoch_end(outputs)
|
||||
|
||||
def on_test_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
|
@ -380,9 +380,9 @@ def test_trainer_model_hook_system(tmpdir):
|
|||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_epoch_start()
|
||||
|
||||
def on_test_epoch_end(self):
|
||||
def on_test_epoch_end(self, outputs):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_epoch_end()
|
||||
super().on_test_epoch_end(outputs)
|
||||
|
||||
def on_validation_model_eval(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
|
|
|
@ -126,7 +126,6 @@ def test__validation_step__step_end__epoch_end__log(tmpdir):
|
|||
def validation_epoch_end(self, outputs):
|
||||
self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
|
||||
self.validation_epoch_end_called = True
|
||||
assert len(self.trainer.evaluation_loop.outputs) == 0
|
||||
|
||||
def backward(self, loss, optimizer, optimizer_idx):
|
||||
return LightningModule.backward(self, loss, optimizer, optimizer_idx)
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest import mock
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.trainer.evaluation_loop.EvaluationLoop.call_on_evaluation_epoch_end_hook")
|
||||
def test_call_on_evaluation_epoch_end_hook(eval_epoch_end_mock, tmpdir):
|
||||
"""
|
||||
Tests that `call_on_evaluation_epoch_end_hook` is called
|
||||
for `on_validation_epoch_end` and `on_test_epoch_end` hooks
|
||||
"""
|
||||
model = BoringModel()
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=2,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
trainer.fit(model)
|
||||
# sanity + 2 epochs
|
||||
assert eval_epoch_end_mock.call_count == 3
|
||||
|
||||
trainer.test()
|
||||
# sanity + 2 epochs + called once for test
|
||||
assert eval_epoch_end_mock.call_count == 4
|
Loading…
Reference in New Issue