lightning/pytorch_lightning/loops/dataloader/evaluation_loop.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import EpochLoopProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


class EvaluationLoop(DataLoaderLoop):
    """Loops over all dataloaders for evaluation."""

    def __init__(self):
        super().__init__()
        self.outputs = []
        self.progress = EpochLoopProgress()
        self.epoch_loop = EvaluationEpochLoop()

        self._results = ResultCollection(training=False)
        self._max_batches: Optional[Union[int, Sequence[int]]] = None
        self._has_run: bool = False

    @property
    def num_dataloaders(self) -> int:
        """Returns the total number of dataloaders."""
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        if dataloaders is None:
            return 0
        length = len(dataloaders)
        if length > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length
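
    # Example for ``num_dataloaders``: a ``val_dataloader`` hook written as
    # ``return dl1, dl2`` yields two dataloaders; the nested-sequence check
    # above counts the inner sequence when the dataloaders arrive wrapped in
    # an extra list or tuple.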

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns the validation or test dataloaders."""
        if self.trainer.testing:
            return self.trainer.test_dataloaders
        return self.trainer.val_dataloaders

    @property
    def predictions(self):
        """Returns the predictions from all dataloaders."""
        return self.epoch_loop.predictions

    def connect(
        self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any
    ) -> None:
        """Connects the loop with necessary arguments like the trainer."""
        super().connect(trainer, *args, **kwargs)
        if progress is not None:
            self.progress = progress
        self.epoch_loop.connect(trainer, progress=self.progress.epoch)
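
    # Note that the child ``EvaluationEpochLoop`` is handed ``self.progress.epoch``,
    # so both loops record their progress on the same shared object.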

    @property
    def done(self) -> bool:
        """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
        return (self.current_dataloader_idx >= len(self.dataloaders)) or self.skip

    @property
    def skip(self) -> bool:
        """Returns whether the evaluation should be skipped."""
        max_batches = self.get_max_batches()
        return sum(max_batches) == 0
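
    # For example, ``Trainer(limit_val_batches=0)`` makes every entry of
    # ``max_batches`` zero, so ``skip`` is True and the whole run is short-circuited.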

    def reset(self) -> None:
        """Resets the internal state of the loop."""
        self.iteration_count = 0
        self._max_batches = self.get_max_batches()
        # bookkeeping
        self.outputs = []

        if isinstance(self._max_batches, int):
            self._max_batches = [self._max_batches] * len(self.dataloaders)
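
    # After ``reset``, ``_max_batches`` holds one entry per dataloader, e.g. an
    # integer limit of ``2`` with three val dataloaders becomes ``[2, 2, 2]``.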

    def on_skip(self) -> List:
        return []

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs the ``on_evaluation_model_eval``, ``on_evaluation_start`` and ``on_evaluation_epoch_start`` hooks."""
        void(*args, **kwargs)
        # hook
        self.on_evaluation_model_eval()
        self.trainer.lightning_module.zero_grad()
        self.on_evaluation_start()
        self.on_evaluation_epoch_start()
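
    # The ``zero_grad`` call above clears any parameter gradients left over from
    # the preceding training steps before evaluation begins.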

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Performs evaluation on one single dataloader."""
        void(*args, **kwargs)
        dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
        dataloader_iter = enumerate(dataloader)
        dl_max_batches = self._max_batches[self.current_dataloader_idx]

        dl_outputs = self.epoch_loop.run(
            dataloader_iter,
            self.current_dataloader_idx,
            dl_max_batches,
            self.num_dataloaders,
        )

        # store batch level output per dataloader
        if self.should_track_batch_outputs_for_epoch_end:
            self.outputs.append(dl_outputs)

        if not self.trainer.sanity_checking:
            # indicate the loop has run
            self._has_run = True
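
    # Each ``advance`` call consumes exactly one dataloader; the base
    # ``DataLoaderLoop`` keeps calling it until ``done`` reports that all
    # dataloaders have been processed.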

    def on_run_end(self) -> Any:
        """Runs the ``on_evaluation_epoch_end`` hook."""
        outputs = self.outputs

        # free memory
        self.outputs = []

        # with a single dataloader don't pass a 2D list
        if len(outputs) > 0 and self.num_dataloaders == 1:
            outputs = outputs[0]

        # lightning module method
        self.evaluation_epoch_end(outputs)

        # hook
        self.on_evaluation_epoch_end()

        # log epoch metrics
        eval_loop_results = self.trainer.logger_connector.update_eval_epoch_metrics()

        # hook
        self.on_evaluation_end()

        # save predictions to disk
        self.epoch_loop.predictions.to_disk()

        # enable train mode again
        self.on_evaluation_model_train()

        return eval_loop_results
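
    # Shape of the ``outputs`` passed on to ``{validation,test}_epoch_end``: with
    # one dataloader it is a flat list of batch outputs; with several it is a list
    # of such lists, one per dataloader.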

    def get_max_batches(self) -> List[Union[int, float]]:
        """Returns the max number of batches for each dataloader."""
        if self.trainer.testing:
            max_batches = self.trainer.num_test_batches
        else:
            if self.trainer.sanity_checking:
                self.trainer.num_sanity_val_batches = [
                    min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
                ]
                max_batches = self.trainer.num_sanity_val_batches
            else:
                max_batches = self.trainer.num_val_batches
        return max_batches
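
    # During sanity checking each dataloader is capped at ``num_sanity_val_steps``
    # batches, or fewer when the dataloader itself is shorter than that.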

    def reload_evaluation_dataloaders(self) -> None:
        """Reloads dataloaders if necessary."""
        model = self.trainer.lightning_module
        if self.trainer.testing:
            self.trainer.reset_test_dataloader(model)
        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
            self.trainer.reset_val_dataloader(model)

    def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_start`` hooks."""
        self.should_track_batch_outputs_for_epoch_end: bool = self._should_track_batch_outputs_for_epoch_end()

        assert self._results is not None
        self._results.to(device=self.trainer.lightning_module.device)

        if self.trainer.testing:
            self.trainer.call_hook("on_test_start", *args, **kwargs)
        else:
            self.trainer.call_hook("on_validation_start", *args, **kwargs)

    def on_evaluation_model_eval(self) -> None:
        """Sets model to eval mode."""
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_eval()
        else:
            model_ref.on_validation_model_eval()

    def on_evaluation_model_train(self) -> None:
        """Sets model to train mode."""
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_train()
        else:
            model_ref.on_validation_model_train()

    def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_end`` hook."""
        if self.trainer.testing:
            self.trainer.call_hook("on_test_end", *args, **kwargs)
        else:
            self.trainer.call_hook("on_validation_end", *args, **kwargs)

        if self.trainer.state.fn != TrainerFn.FITTING:
            # summarize profile results
            self.trainer.profiler.describe()

        # reset any `torchmetrics.Metric` and the logger connector state
        self.trainer.logger_connector.reset(metrics=True)
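
    # The profiler summary is printed only for standalone ``validate``/``test``
    # runs; while fitting (``TrainerFn.FITTING``) it is skipped here.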

    def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start", *args, **kwargs)

        if self.trainer.testing:
            self.trainer.call_hook("on_test_epoch_start", *args, **kwargs)
        else:
            self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)

    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        """Whether the batch outputs should be stored for later usage."""
        model = self.trainer.lightning_module
        if self.trainer.testing:
            return is_overridden("test_epoch_end", model)
        return is_overridden("validation_epoch_end", model)

    def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """Runs ``{validation/test}_epoch_end``."""
        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # call the model epoch end
        model = self.trainer.lightning_module

        # unset dataloader_idx in model
        model._current_dataloader_idx = None

        if self.trainer.testing:
            if is_overridden("test_epoch_end", model):
                model._current_fx_name = "test_epoch_end"
                model.test_epoch_end(outputs)
        else:
            if is_overridden("validation_epoch_end", model):
                model._current_fx_name = "validation_epoch_end"
                model.validation_epoch_end(outputs)

    def on_evaluation_epoch_end(self) -> None:
        """Runs ``on_{validation/test}_epoch_end`` hook."""
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
        self.trainer.call_hook(hook_name)
        self.trainer.call_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()

    def teardown(self) -> None:
        self._results.cpu()
        self.epoch_loop.teardown()