# Source path: lightning/pytorch_lightning/loops/dataloader/evaluation_loop.py (Python, 289 lines, 12 KiB)

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Sequence, Union
import torch
from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
class EvaluationLoop(DataLoaderLoop):
    """Loops over all dataloaders for evaluation.

    For every dataloader the loop delegates the batch iteration to ``self.epoch_loop`` and collects both the raw
    per-dataloader outputs and the metrics reported by the trainer's logger connector. Whether the validation or the
    test hooks/dataloaders are used is decided everywhere by ``self.trainer.testing``.

    Args:
        verbose: If ``True``, the collected metrics are printed at the end of the run (only when
            ``trainer.is_global_zero`` is truthy).
    """

    def __init__(self, verbose: bool = True) -> None:
        super().__init__()
        self.epoch_loop = EvaluationEpochLoop()
        self.verbose = verbose

        self._results = _ResultCollection(training=False)
        # batch-level outputs, one entry appended per processed dataloader
        self._outputs: List[EPOCH_OUTPUT] = []
        # logged metrics, one entry appended per processed dataloader
        self._logged_outputs: List[_OUT_DICT] = []
        # batch limit per dataloader, refreshed in `reset`
        self._max_batches: List[int] = []
        # set to True once `advance` ran outside of sanity checking
        self._has_run: bool = False

    @property
    def num_dataloaders(self) -> int:
        """Returns the total number of dataloaders.

        If the first element of ``self.dataloaders`` is itself a list or tuple, the length of that nested sequence
        is returned instead of the length of the outer one.
        """
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        # NOTE: `dataloaders` as defined in this class raises instead of returning `None`, so this guard can only
        # trigger when the property is overridden; it is kept for backward compatibility.
        if dataloaders is None:
            return 0
        length = len(dataloaders)
        if length > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns the validation or test dataloaders.

        Raises:
            RuntimeError: If the selected trainer attribute (``test_dataloaders`` when testing, otherwise
                ``val_dataloaders``) is ``None``.
        """
        dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders
        if dataloaders is None:
            raise RuntimeError("Dataloaders should be available.")
        return dataloaders

    def connect(self, epoch_loop: EvaluationEpochLoop) -> None:  # type: ignore[override]
        """Connect the evaluation epoch loop with this loop.

        Args:
            epoch_loop: The loop that replaces ``self.epoch_loop``.
        """
        self.epoch_loop = epoch_loop

    @property
    def done(self) -> bool:
        """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
        return super().done or self.skip

    @property
    def skip(self) -> bool:
        """Returns whether the evaluation should be skipped.

        Evaluation is skipped when the per-dataloader batch limits sum to zero.
        """
        # NOTE(review): `sum` requires an iterable here, whereas `reset` also tolerates a plain int coming from
        # `_get_max_batches` - confirm that an int can never reach this property.
        max_batches = self._get_max_batches()
        return sum(max_batches) == 0

    def reset(self) -> None:
        """Resets the internal state of the loop."""
        self._max_batches = self._get_max_batches()
        # bookkeeping
        self._outputs = []
        self._logged_outputs = []

        # a single int limit is expanded to one entry per dataloader
        if isinstance(self._max_batches, int):
            self._max_batches = [self._max_batches] * len(self.dataloaders)

        super().reset()
        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
        # need to reset the current state when the loop has finished running
        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
            self.dataloader_progress.reset_on_run()

    def on_skip(self) -> List:
        """Returns an empty list as the result of a skipped evaluation."""
        return []

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
        hooks."""
        # the arguments are accepted but not used by this method
        void(*args, **kwargs)

        # hook
        self._on_evaluation_model_eval()
        self.trainer.lightning_module.zero_grad()
        self._on_evaluation_start()
        self._on_evaluation_epoch_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Performs evaluation on one single dataloader."""
        # the arguments are accepted but not used by this method
        void(*args, **kwargs)

        dataloader_idx = self.current_dataloader_idx
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
            dataloader, dataloader_idx=dataloader_idx
        )
        dl_max_batches = self._max_batches[dataloader_idx]

        # the dataloader index is only forwarded when there is more than one dataloader
        dl_outputs = self.epoch_loop.run(
            dataloader, dataloader_idx if self.num_dataloaders > 1 else None, dl_max_batches
        )

        # store batch level output per dataloader
        self._outputs.append(dl_outputs)

        if not self.trainer.sanity_checking:
            # indicate the loop has run
            self._has_run = True

    def on_advance_end(self) -> None:
        """Marks the epoch end on the logger connector and stores the metrics logged for this dataloader."""
        self.trainer.logger_connector.epoch_end_reached()

        self._logged_outputs.append(self.trainer.logger_connector.update_eval_epoch_metrics())

        super().on_advance_end()

    def on_run_end(self) -> List[_OUT_DICT]:
        """Runs the ``_on_evaluation_epoch_end`` hook.

        Returns:
            The logged metrics, one dictionary per processed dataloader.
        """
        # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
        self.trainer.logger_connector.epoch_end_reached()

        # hook
        self._evaluation_epoch_end(self._outputs)
        self._outputs = []  # free memory

        # hook
        self._on_evaluation_epoch_end()

        logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
        # include any logged outputs on epoch_end
        if self.num_dataloaders < 2:  # TODO: remove this check
            epoch_end_logged_outputs = self.trainer.logger_connector.update_eval_epoch_metrics()
            for dl_outputs in logged_outputs:
                dl_outputs.update(epoch_end_logged_outputs)

        # log metrics
        self.trainer.logger_connector.log_eval_end_metrics()

        # hook
        self._on_evaluation_end()

        # enable train mode again
        self._on_evaluation_model_train()

        if self.verbose and self.trainer.is_global_zero:
            assert self.trainer.state.stage is not None
            self._print_results(logged_outputs, self.trainer.state.stage)

        return logged_outputs

    def teardown(self) -> None:
        """Calls ``cpu()`` on the result collection and tears down the epoch loop."""
        self._results.cpu()
        self.epoch_loop.teardown()

    def _get_max_batches(self) -> List[int]:
        """Returns the max number of batches for each dataloader.

        When sanity checking, ``trainer.num_sanity_val_batches`` is (re)computed as a side effect by capping every
        entry of ``trainer.num_val_batches`` at ``trainer.num_sanity_val_steps``.
        """
        if self.trainer.testing:
            max_batches = self.trainer.num_test_batches
        else:
            if self.trainer.sanity_checking:
                self.trainer.num_sanity_val_batches = [
                    min(self.trainer.num_sanity_val_steps, val_batches) for val_batches in self.trainer.num_val_batches
                ]
                max_batches = self.trainer.num_sanity_val_batches
            else:
                max_batches = self.trainer.num_val_batches
        return max_batches

    def _reload_evaluation_dataloaders(self) -> None:
        """Reloads dataloaders if necessary.

        The test dataloader is always reset when testing; the validation dataloader is reset only if it is missing
        or ``trainer._should_reload_dl_epoch`` is truthy.
        """
        if self.trainer.testing:
            self.trainer.reset_test_dataloader()
        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
            self.trainer.reset_val_dataloader()

    def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_start`` hooks.

        The result collection is first moved to the device of the lightning module. The hooks are then called on the
        callbacks, the lightning module and the strategy, in that order.
        """
        assert self._results is not None
        self._results.to(device=self.trainer.lightning_module.device)

        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
            self.trainer._call_strategy_hook("on_test_start", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
            self.trainer._call_strategy_hook("on_validation_start", *args, **kwargs)

    def _on_evaluation_model_eval(self) -> None:
        """Sets model to eval mode by calling the ``on_{validation/test}_model_eval`` lightning module hook."""
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("on_test_model_eval")
        else:
            self.trainer._call_lightning_module_hook("on_validation_model_eval")

    def _on_evaluation_model_train(self) -> None:
        """Sets model to train mode by calling the ``on_{validation/test}_model_train`` lightning module hook."""
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("on_test_model_train")
        else:
            self.trainer._call_lightning_module_hook("on_validation_model_train")

    def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_end`` hook.

        The hooks are called on the callbacks, the lightning module and the strategy, in that order. Afterwards the
        logger connector results are reset.
        """
        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
            self.trainer._call_strategy_hook("on_test_end", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
            self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)

        # reset the logger connector state
        self.trainer.logger_connector.reset_results()

    def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
        self.trainer.logger_connector.on_epoch_start()
        self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
        self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)

        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_epoch_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_epoch_start", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_epoch_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_epoch_start", *args, **kwargs)

    def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
        """Runs ``{validation/test}_epoch_end``

        Args:
            outputs: The collected outputs, one entry per processed dataloader.
        """
        self.trainer.logger_connector._evaluation_epoch_end()

        # with a single dataloader don't pass a 2D list
        output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
            outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs
        )

        # call the model epoch end
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("test_epoch_end", output_or_outputs)
        else:
            self.trainer._call_lightning_module_hook("validation_epoch_end", output_or_outputs)

    def _on_evaluation_epoch_end(self) -> None:
        """Runs ``on_{validation/test}_epoch_end`` hook.

        The stage-specific hook is followed by the generic ``on_epoch_end`` hook, each called on the callbacks first
        and on the lightning module second. Finally the logger connector is notified through ``on_epoch_end``.
        """
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
        self.trainer._call_callback_hooks(hook_name)
        self.trainer._call_lightning_module_hook(hook_name)

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()

    def _print_results(self, results: List[_OUT_DICT], stage: RunningStage) -> None:
        """Prints the metrics of every dataloader to stdout.

        Tensor values are converted for printing: tensors with exactly one element through ``item()``, all other
        tensors through ``tolist()``. Non-tensor values are printed unchanged.

        Args:
            results: The logged metrics, one dictionary per dataloader.
            stage: The running stage; its upper-cased form is used in the printed header.
        """
        # TODO: this could be updated to look nicer
        from pprint import pprint

        print("-" * 80)
        for i, metrics_dict in enumerate(results):
            print(f"DATALOADER:{i} {stage.upper()} RESULTS")
            pprint(
                {
                    k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
                    for k, v in metrics_dict.items()
                }
            )
            print("-" * 80)