[2/2] Remove outputs from evaluation epoch end hooks (#7338)

* Remove outputs from on_train_epoch_end

* iterate

* Update callback_hook.py

* update

* early stop?

* fix

* Update pytorch_lightning/trainer/training_loop.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* Update trainer.py

* update

* Update training_loop.py

* early stop?

* fix

* Remove outputs from evaluation epoch end hooks

* update

* Update test_remove_1-5.py

* fix lints

* Update base.py

* rm-outputs

* Update evaluation_loop.py

* try-save-more-memory

* Update trainer.py

* Update trainer.py

* cache-at-start

* Update evaluation_loop.py

* Update training_loop.py

* Update training_loop.py

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Author: ananthsub (2021-05-05 12:50:58 -07:00), committed by GitHub
Parent: fbcd63aa89
Commit: 7b45bcfedb
GPG Key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
11 changed files with 39 additions and 252 deletions

View File

@@ -88,9 +88,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for the PyTorch 1.8.1 autograd profiler ([#6618](https://github.com/PyTorchLightning/pytorch-lightning/pull/6618))
- Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120))
- Added `configure_sharded_model` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679))
@@ -213,6 +210,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))
- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
- Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292))
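For readers tracking the migration: with the `outputs` argument gone from the `on_*_epoch_end` hooks, user code caches step results itself. A minimal sketch of the pattern, assuming an illustrative module and attribute name (not part of this PR):

import torch
from pytorch_lightning import LightningModule

class MigratedModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self._val_losses = []  # hypothetical cache replacing the removed `outputs` argument

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self._val_losses.append(loss.detach())  # keep only what the epoch-end logic needs
        return loss

    def on_validation_epoch_end(self):  # post-removal signature: no `outputs`
        self.log('val_loss_epoch', torch.stack(self._val_losses).mean())
        self._val_losses.clear()  # release references so memory is freed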

View File

@@ -22,7 +22,7 @@ from typing import Any, Dict, List, Optional
from torch.optim import Optimizer
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT
class Callback(abc.ABC):
@@ -108,9 +108,7 @@ class Callback(abc.ABC):
"""Called when the val epoch begins."""
pass
def on_validation_epoch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT
) -> None:
def on_validation_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the val epoch ends."""
pass
@@ -118,7 +116,7 @@ class Callback(abc.ABC):
"""Called when the test epoch begins."""
pass
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
def on_test_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
"""Called when the test epoch ends."""
pass
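The same migration applies to callbacks: under the two-argument signatures above, a callback that needs per-batch results collects them via `on_validation_batch_end`, which still receives `outputs`. A hedged sketch (the callback name and attribute are illustrative):

import pytorch_lightning as pl

class ValOutputsCallback(pl.Callback):
    def __init__(self):
        self.outputs = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self.outputs.append(outputs)  # the batch-level hook keeps its `outputs` parameter

    def on_validation_epoch_end(self, trainer, pl_module):  # new two-argument signature
        # consume the collected outputs here, then drop the references
        self.outputs.clear()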

View File

@@ -20,7 +20,7 @@ from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT
class ModelHooks:
@@ -245,7 +245,7 @@ class ModelHooks:
Called in the validation loop at the very beginning of the epoch.
"""
def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def on_validation_epoch_end(self) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
@@ -255,7 +255,7 @@ class ModelHooks:
Called in the test loop at the very beginning of the epoch.
"""
def on_test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def on_test_epoch_end(self) -> None:
"""
Called in the test loop at the very end of the epoch.
"""

View File

@@ -111,44 +111,20 @@ class TrainerCallbackHookMixin(ABC):
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.lightning_module)
def on_validation_epoch_end(self, outputs: EPOCH_OUTPUT):
"""Called when the epoch ends.
Args:
outputs: List of outputs on each ``validation`` epoch
"""
def on_validation_epoch_end(self):
"""Called when the validation epoch ends."""
for callback in self.callbacks:
if is_param_in_hook_signature(callback.on_validation_epoch_end, "outputs"):
callback.on_validation_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_validation_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_validation_epoch_end(self, self.lightning_module)
callback.on_validation_epoch_end(self, self.lightning_module)
def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.lightning_module)
def on_test_epoch_end(self, outputs: EPOCH_OUTPUT):
"""Called when the epoch ends.
Args:
outputs: List of outputs on each ``test`` epoch
"""
def on_test_epoch_end(self):
"""Called when the test epoch ends."""
for callback in self.callbacks:
if is_param_in_hook_signature(callback.on_test_epoch_end, "outputs"):
callback.on_test_epoch_end(self, self.lightning_module, outputs)
else:
warning_cache.warn(
"`Callback.on_test_epoch_end` signature has changed in v1.3."
" `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_test_epoch_end(self, self.lightning_module)
callback.on_test_epoch_end(self, self.lightning_module)
def on_predict_epoch_start(self) -> None:
"""Called when the epoch begins."""

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
from torch.utils.data import DataLoader
@@ -20,7 +20,7 @@ from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import PredictionCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
@@ -76,6 +75,7 @@ class EvaluationLoop(object):
return sum(max_batches) == 0
def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()
if self.trainer.testing:
self.trainer.call_hook('on_test_start', *args, **kwargs)
else:
@@ -188,6 +188,13 @@ class EvaluationLoop(object):
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
return output
def _should_track_batch_outputs_for_epoch_end(self) -> bool:
model = self.trainer.lightning_module
if self.trainer.testing:
return is_overridden('test_epoch_end', model=model)
else:
return is_overridden('validation_epoch_end', model=model)
def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
# unset dataloder_idx in model
self.trainer.logger_connector.evaluation_epoch_end()
@@ -241,7 +248,7 @@ class EvaluationLoop(object):
# track debug metrics
self.trainer.dev_debugger.track_eval_loss_history(batch_idx, dataloader_idx, output)
def on_evaluation_epoch_end(self, outputs: Union[List[List[Dict]], List[Dict]]) -> None:
def on_evaluation_epoch_end(self) -> None:
model_ref = self.trainer.lightning_module
hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
@@ -251,18 +258,11 @@ class EvaluationLoop(object):
if hasattr(self.trainer, hook_name):
on_evaluation_epoch_end_hook = getattr(self.trainer, hook_name)
on_evaluation_epoch_end_hook(outputs)
on_evaluation_epoch_end_hook()
if is_overridden(hook_name, model_ref):
model_hook_fx = getattr(model_ref, hook_name)
if is_param_in_hook_signature(model_hook_fx, "outputs"):
model_hook_fx(outputs)
else:
self.warning_cache.warn(
f"`ModelHooks.{hook_name}` signature has changed in v1.3. `outputs` parameter has been added."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
model_hook_fx()
model_hook_fx()
self.trainer._cache_logged_metrics()
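The memory saving hinges on the `is_overridden` gate above: batch outputs are only accumulated when the model actually consumes them in `*_epoch_end`. A tiny illustration of that check, assuming the helper's 1.3-era semantics:

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.model_helpers import is_overridden

class Plain(LightningModule):
    pass  # no `validation_epoch_end`: the loop now skips output tracking entirely

class Consumer(LightningModule):
    def validation_epoch_end(self, outputs):  # overriding this re-enables tracking
        pass

assert not is_overridden('validation_epoch_end', model=Plain())
assert is_overridden('validation_epoch_end', model=Consumer())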

View File

@@ -972,7 +972,8 @@ class Trainer(
dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)
# store batch level output per dataloader
self.evaluation_loop.outputs.append(dl_outputs)
if self.evaluation_loop.should_track_batch_outputs_for_epoch_end:
self.evaluation_loop.outputs.append(dl_outputs)
outputs = self.evaluation_loop.outputs
@@ -980,14 +981,14 @@ class Trainer(
self.evaluation_loop.outputs = []
# with a single dataloader don't pass a 2D list
if self.evaluation_loop.num_dataloaders == 1:
if len(outputs) > 0 and self.evaluation_loop.num_dataloaders == 1:
outputs = outputs[0]
# lightning module method
self.evaluation_loop.evaluation_epoch_end(outputs)
# hook
self.evaluation_loop.on_evaluation_epoch_end(outputs)
self.evaluation_loop.on_evaluation_epoch_end()
# update epoch-level lr_schedulers
if on_epoch:
@@ -1212,8 +1213,8 @@ class Trainer(
def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
# Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
# This was done to manage the deprecation of an argument to on_train_epoch_end
# If making chnages to this function, ensure that those changes are also made to
# This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end
# If making changes to this function, ensure that those changes are also made to
# TrainLoop._on_train_epoch_end_hook
# set hook_name to model + reset Result obj

View File

@@ -65,48 +65,6 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
trainer.fit(model)
def test_on_val_epoch_end_outputs(tmpdir):
class CB(Callback):
def on_validation_epoch_end(self, trainer, pl_module, outputs):
if trainer.running_sanity_check:
assert len(outputs) == trainer.num_sanity_val_batches[0]
else:
assert len(outputs) == trainer.num_val_batches[0]
model = BoringModel()
trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)
def test_on_test_epoch_end_outputs(tmpdir):
class CB(Callback):
def on_test_epoch_end(self, trainer, pl_module, outputs):
assert len(outputs) == trainer.num_test_batches[0]
model = BoringModel()
trainer = Trainer(
callbacks=CB(),
default_root_dir=tmpdir,
weights_summary=None,
)
trainer.test(model)
def test_free_memory_on_eval_outputs(tmpdir):
class CB(Callback):

View File

@@ -58,7 +58,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_sanity_check_end(trainer, model),
@@ -90,7 +90,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
@@ -128,7 +128,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_test_batch_start(trainer, model, ANY, 1, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_test_epoch_end(trainer, model, ANY),
call.on_test_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_test_end(trainer, model),
call.teardown(trainer, model, 'test'),
@@ -163,7 +163,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
call.on_validation_batch_start(trainer, model, ANY, 1, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 1, 0),
call.on_validation_epoch_end(trainer, model, ANY),
call.on_validation_epoch_end(trainer, model),
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.teardown(trainer, model, 'validate'),

View File

@@ -1,56 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel
def test_on_val_epoch_end_outputs(tmpdir):
class TestModel(BoringModel):
def on_validation_epoch_end(self, outputs):
if trainer.running_sanity_check:
assert len(outputs) == trainer.num_sanity_val_batches[0]
else:
assert len(outputs) == trainer.num_val_batches[0]
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)
def test_on_test_epoch_end_outputs(tmpdir):
class TestModel(BoringModel):
def on_test_epoch_end(self, outputs):
assert len(outputs) == trainer.num_test_batches[0]
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=2,
weights_summary=None,
)
trainer.test(model)
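Since the hook no longer carries `outputs`, an equivalent assertion would count batches itself. A hypothetical replacement test under the new signature (the test name and counter attribute are illustrative, not part of this commit):

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel

def test_val_epoch_end_without_outputs(tmpdir):
    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.seen = 0

        def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
            self.seen += 1  # track what the removed `outputs` argument used to carry

        def on_validation_epoch_end(self):
            if self.trainer.running_sanity_check:
                assert self.seen == self.trainer.num_sanity_val_batches[0]
            else:
                assert self.seen == self.trainer.num_val_batches[0]
            self.seen = 0

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(TestModel())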

View File

@@ -263,96 +263,6 @@ def test_v1_5_0_old_on_train_epoch_end(tmpdir):
trainer.fit(model)
def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
callback_warning_cache.clear()
class OldSignature(Callback):
def on_validation_epoch_end(self, trainer, pl_module): # noqa
...
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)
class OldSignatureModel(BoringModel):
def on_validation_epoch_end(self): # noqa
...
model = OldSignatureModel()
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)
callback_warning_cache.clear()
class NewSignature(Callback):
def on_validation_epoch_end(self, trainer, pl_module, outputs):
...
trainer.callbacks = [NewSignature()]
with no_deprecated_call(match="`Callback.on_validation_epoch_end` signature has changed in v1.3."):
trainer.fit(model)
class NewSignatureModel(BoringModel):
def on_validation_epoch_end(self, outputs):
...
model = NewSignatureModel()
with no_deprecated_call(match="`ModelHooks.on_validation_epoch_end` signature has changed in v1.3."):
trainer.fit(model)
def test_v1_5_0_old_on_test_epoch_end(tmpdir):
callback_warning_cache.clear()
class OldSignature(Callback):
def on_test_epoch_end(self, trainer, pl_module): # noqa
...
model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.test(model)
class OldSignatureModel(BoringModel):
def on_test_epoch_end(self): # noqa
...
model = OldSignatureModel()
with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.test(model)
callback_warning_cache.clear()
class NewSignature(Callback):
def on_test_epoch_end(self, trainer, pl_module, outputs):
...
trainer.callbacks = [NewSignature()]
with no_deprecated_call(match="`Callback.on_test_epoch_end` signature has changed in v1.3."):
trainer.test(model)
class NewSignatureModel(BoringModel):
def on_test_epoch_end(self, outputs):
...
model = NewSignatureModel()
with no_deprecated_call(match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
trainer.test(model)
@pytest.mark.parametrize("cls", (BaseProfiler, SimpleProfiler, AdvancedProfiler, PyTorchProfiler))
def test_v1_5_0_profiler_output_filename(tmpdir, cls):
filepath = str(tmpdir / "test.txt")

View File

@@ -681,10 +681,10 @@ def test_metrics_reset(tmpdir):
def on_train_epoch_end(self):
self._assert_epoch_end('train')
def on_validation_epoch_end(self, outputs):
def on_validation_epoch_end(self):
self._assert_epoch_end('val')
def on_test_epoch_end(self, outputs):
def on_test_epoch_end(self):
self._assert_epoch_end('test')
def _assert_called(model, stage):