lightning/pytorch_lightning/loops/dataloader/evaluation_loop.py

247 lines
9.9 KiB
Python
Raw Normal View History

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Sequence, Union
from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
class EvaluationLoop(DataLoaderLoop):
    """Loops over all dataloaders for evaluation (validation, sanity checking or testing)."""

    def __init__(self) -> None:
        super().__init__()
        self.epoch_loop = EvaluationEpochLoop()

        # metric results collected while evaluating (non-training mode)
        self._results = ResultCollection(training=False)
        # one entry per dataloader: the output returned by ``epoch_loop.run`` in ``advance``
        self._outputs: List[EPOCH_OUTPUT] = []
        # maximum number of batches to run, one entry per dataloader (filled in ``reset``)
        self._max_batches: List[int] = []
        # flipped to True by ``advance`` once a dataloader was run outside of sanity checking
        self._has_run: bool = False

    @property
    def num_dataloaders(self) -> int:
        """Returns the total number of dataloaders.

        Raises:
            RuntimeError: propagated from :attr:`dataloaders` when no dataloaders are available.
        """
        # ``self.dataloaders`` never returns ``None`` (it raises instead), so no ``None`` check is needed here.
        dataloaders = self.dataloaders
        length = len(dataloaders)
        # case where user does:
        # return dl1, dl2
        # in which case the dataloaders arrive wrapped in an outer list/tuple
        if length > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns the validation or test dataloaders, depending on ``trainer.testing``.

        Raises:
            RuntimeError: if the selected dataloaders are ``None``.
        """
        dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders
        if dataloaders is None:
            raise RuntimeError("Dataloaders should be available.")
        return dataloaders

    def connect(self, epoch_loop: EvaluationEpochLoop) -> None:  # type: ignore[override]
        """Connect the evaluation epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    @property
    def done(self) -> bool:
        """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
        return super().done or self.skip

    @property
    def skip(self) -> bool:
        """Returns whether the evaluation should be skipped, i.e. no dataloader has any batch to run."""
        max_batches = self._get_max_batches()
        # ``reset`` accepts a scalar limit and broadcasts it; handle that shape here too instead of
        # letting ``sum`` fail on an ``int``.
        if isinstance(max_batches, int):
            return max_batches == 0
        return sum(max_batches) == 0

    def reset(self) -> None:
        """Resets the internal state of the loop."""
        self._max_batches = self._get_max_batches()
        # bookkeeping
        self._outputs = []

        if isinstance(self._max_batches, int):
            # broadcast a scalar limit to every dataloader
            self._max_batches = [self._max_batches] * len(self.dataloaders)

        super().reset()

    def on_skip(self) -> List:
        """Returns an empty list of results when evaluation is skipped."""
        return []

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
        hooks."""
        void(*args, **kwargs)

        # hook
        self._on_evaluation_model_eval()
        self.trainer.lightning_module.zero_grad()
        self._on_evaluation_start()
        self._on_evaluation_epoch_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Performs evaluation on one single dataloader."""
        void(*args, **kwargs)

        dataloader_idx = self.current_dataloader_idx
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
            dataloader, dataloader_idx=dataloader_idx
        )
        dl_max_batches = self._max_batches[dataloader_idx]

        # the dataloader index is only forwarded when there is more than one dataloader
        dl_outputs = self.epoch_loop.run(
            dataloader, dataloader_idx if self.num_dataloaders > 1 else None, dl_max_batches
        )

        # store batch level output per dataloader
        self._outputs.append(dl_outputs)

        if not self.trainer.sanity_checking:
            # indicate the loop has run
            self._has_run = True

    def on_run_end(self) -> List[_OUT_DICT]:
        """Runs the epoch-end and evaluation-end hooks and returns the logged epoch metrics.

        Returns:
            The result of ``trainer.logger_connector.update_eval_epoch_metrics()``.
        """
        outputs, self._outputs = self._outputs, []  # free memory

        # lightning module method
        self._evaluation_epoch_end(outputs)

        # hook
        self._on_evaluation_epoch_end()

        # log epoch metrics
        eval_loop_results = self.trainer.logger_connector.update_eval_epoch_metrics()

        # hook
        self._on_evaluation_end()

        # enable train mode again
        self._on_evaluation_model_train()

        return eval_loop_results

    def teardown(self) -> None:
        """Moves the collected results to CPU and tears down the epoch loop."""
        self._results.cpu()
        self.epoch_loop.teardown()

    def _get_max_batches(self) -> List[int]:
        """Returns the max number of batches for each dataloader.

        Side effect: while sanity checking, this recomputes ``trainer.num_sanity_val_batches`` as the
        element-wise minimum of ``trainer.num_sanity_val_steps`` and ``trainer.num_val_batches``.
        """
        if self.trainer.testing:
            max_batches = self.trainer.num_test_batches
        else:
            if self.trainer.sanity_checking:
                self.trainer.num_sanity_val_batches = [
                    min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
                ]
                max_batches = self.trainer.num_sanity_val_batches
            else:
                max_batches = self.trainer.num_val_batches
        return max_batches

    def _reload_evaluation_dataloaders(self) -> None:
        """Reloads dataloaders if necessary.

        Test dataloaders are always reset when testing; validation dataloaders are reset only when they
        are missing or ``trainer._should_reload_dl_epoch`` is set.
        """
        if self.trainer.testing:
            self.trainer.reset_test_dataloader()
        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
            self.trainer.reset_val_dataloader()

    def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_start`` hooks."""
        assert self._results is not None
        # keep the result collection on the same device as the module
        self._results.to(device=self.trainer.lightning_module.device)

        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
            self.trainer._call_ttp_hook("on_test_start", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
            self.trainer._call_ttp_hook("on_validation_start", *args, **kwargs)

    def _on_evaluation_model_eval(self) -> None:
        """Sets model to eval mode."""
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("on_test_model_eval")
        else:
            self.trainer._call_lightning_module_hook("on_validation_model_eval")

    def _on_evaluation_model_train(self) -> None:
        """Sets model to train mode."""
        # dispatch through the trainer, like ``_on_evaluation_model_eval`` and every other hook here,
        # rather than calling the module method directly
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("on_test_model_train")
        else:
            self.trainer._call_lightning_module_hook("on_validation_model_train")

    def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_end`` hook."""
        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
            self.trainer._call_ttp_hook("on_test_end", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
            self.trainer._call_ttp_hook("on_validation_end", *args, **kwargs)

        # reset the logger connector state
        self.trainer.logger_connector.reset_results()

    def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
        self.trainer.logger_connector.on_epoch_start()
        self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
        self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)

        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_epoch_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_epoch_start", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_epoch_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_epoch_start", *args, **kwargs)

    def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
        """Runs ``{validation/test}_epoch_end``.

        Args:
            outputs: one entry per dataloader, as collected in ``advance``.
        """
        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()

        # with a single dataloader don't pass a 2D list
        output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
            outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs
        )

        # call the model epoch end
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("test_epoch_end", output_or_outputs)
        else:
            self.trainer._call_lightning_module_hook("validation_epoch_end", output_or_outputs)

    def _on_evaluation_epoch_end(self) -> None:
        """Runs ``on_{validation/test}_epoch_end`` and ``on_epoch_end`` hooks."""
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
        self.trainer._call_callback_hooks(hook_name)
        self.trainer._call_lightning_module_hook(hook_name)

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()