lightning/pytorch_lightning/loops/dataloader/evaluation_loop.py

432 lines
18 KiB
Python
Raw Normal View History

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import sys
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
import torch
from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
DataFetcher,
DataLoaderIterDataFetcher,
InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
if _RICH_AVAILABLE:
from rich.console import Console
from rich.table import Column, Table
class EvaluationLoop(DataLoaderLoop):
    """Loops over all dataloaders for evaluation.

    For each dataloader the wrapped ``epoch_loop`` is run once (see :meth:`advance`); the validation/test hooks
    are invoked around it and the logged metrics are collected, returned from :meth:`on_run_end` and optionally
    printed as a table.

    Args:
        verbose: If ``True``, the logged results are printed at the end of the run (only on global rank zero).
    """

    def __init__(self, verbose: bool = True) -> None:
        super().__init__()
        self.epoch_loop = EvaluationEpochLoop()
        self.verbose = verbose

        self._results = _ResultCollection(training=False)
        # batch-level outputs, one entry appended per processed dataloader
        self._outputs: List[EPOCH_OUTPUT] = []
        # logged metrics, one entry appended per processed dataloader
        self._logged_outputs: List[_OUT_DICT] = []
        self._max_batches: List[int] = []
        # set to ``True`` once a dataloader was processed outside of sanity checking
        self._has_run: bool = False
        # created in ``on_run_start`` and released in ``teardown``
        self._data_fetcher: Optional[AbstractDataFetcher] = None

    @property
    def num_dataloaders(self) -> int:
        """Returns the total number of dataloaders."""
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        length = len(dataloaders)
        if length > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns the validation or test dataloaders (an empty list if the trainer has none)."""
        dataloaders = self.trainer.test_dataloaders if self.trainer.testing else self.trainer.val_dataloaders
        if dataloaders is None:
            return []
        return dataloaders

    @property
    def prefetch_batches(self) -> int:
        """Returns the number of batches to prefetch.

        This is ``1`` when the current dataloader's batch count is ``float("inf")`` or when the environment
        variable ``PL_INTER_BATCH_PARALLELISM`` is set to ``"1"``, and ``0`` otherwise.
        """
        batches = self.trainer.num_test_batches if self.trainer.testing else self.trainer.num_val_batches
        is_unsized = batches[self.current_dataloader_idx] == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
        return 1 if is_unsized or inter_batch_parallelism else 0

    def connect(self, epoch_loop: EvaluationEpochLoop) -> None:  # type: ignore[override]
        """Connect the evaluation epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    @property
    def done(self) -> bool:
        """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
        return super().done or self.skip

    @property
    def skip(self) -> bool:
        """Returns whether the evaluation should be skipped."""
        max_batches = self._get_max_batches()
        return sum(max_batches) == 0

    def reset(self) -> None:
        """Resets the internal state of the loop."""
        self._max_batches = self._get_max_batches()
        # bookkeeping
        self._outputs = []
        self._logged_outputs = []

        # a single int limit is expanded to one entry per dataloader
        if isinstance(self._max_batches, int):
            self._max_batches = [self._max_batches] * len(self.dataloaders)

        super().reset()
        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
        # need to reset the current state when the loop has finished running
        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
            self.dataloader_progress.reset_on_run()

    def on_skip(self) -> List:
        """Returns an empty list."""
        return []

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
        hooks."""
        void(*args, **kwargs)

        data_fetcher_cls = _select_data_fetcher_type(self.trainer)
        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)

        # hook
        self._on_evaluation_model_eval()
        self.trainer.lightning_module.zero_grad()
        self._on_evaluation_start()
        self._on_evaluation_epoch_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Performs evaluation on one single dataloader."""
        void(*args, **kwargs)

        dataloader_idx = self.current_dataloader_idx
        dataloader = self.current_dataloader
        assert self._data_fetcher is not None
        self._data_fetcher.setup(
            dataloader,
            batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx),
        )
        dl_max_batches = self._max_batches[dataloader_idx]

        # ``dataloader_idx`` is only forwarded to the epoch loop when there is more than one dataloader
        kwargs = OrderedDict()
        if self.num_dataloaders > 1:
            kwargs["dataloader_idx"] = dataloader_idx
        dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)

        # store batch level output per dataloader
        self._outputs.append(dl_outputs)

        if not self.trainer.sanity_checking:
            # indicate the loop has run
            self._has_run = True

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        """Calls ``set_epoch`` on the current dataloader's sampler, if it provides one, before advancing."""
        dataloader = self.current_dataloader
        if (
            dataloader is not None
            and getattr(dataloader, "sampler", None)
            and callable(getattr(dataloader.sampler, "set_epoch", None))
        ):
            # set seed for distributed sampler (enables shuffling for each epoch)
            dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)

        super().on_advance_start(*args, **kwargs)

    def on_advance_end(self) -> None:
        """Collects the metrics logged for the dataloader that was just processed."""
        self.trainer._logger_connector.epoch_end_reached()

        self._logged_outputs.append(self.trainer._logger_connector.update_eval_epoch_metrics())

        super().on_advance_end()

    def on_run_end(self) -> List[_OUT_DICT]:
        """Runs the ``_on_evaluation_epoch_end`` hook.

        Returns:
            The logged metrics, one dictionary per processed dataloader, each updated with the metrics logged
            during the epoch-end hooks.
        """
        # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
        self.trainer._logger_connector.epoch_end_reached()

        # hook
        self._evaluation_epoch_end(self._outputs)
        self._outputs = []  # free memory

        # hook
        self._on_evaluation_epoch_end()

        logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
        # include any logged outputs on epoch_end
        epoch_end_logged_outputs = self.trainer._logger_connector.update_eval_epoch_metrics()
        all_logged_outputs = dict(ChainMap(*logged_outputs))  # list[dict] -> dict
        all_logged_outputs.update(epoch_end_logged_outputs)
        for dl_outputs in logged_outputs:
            dl_outputs.update(epoch_end_logged_outputs)

        # log metrics
        self.trainer._logger_connector.log_eval_end_metrics(all_logged_outputs)

        # hook
        self._on_evaluation_end()

        # enable train mode again
        self._on_evaluation_model_train()

        if self.verbose and self.trainer.is_global_zero:
            assert self.trainer.state.stage is not None
            self._print_results(logged_outputs, self.trainer.state.stage)

        return logged_outputs

    def teardown(self) -> None:
        """Tears down the data fetcher (if any) and the epoch loop, and moves the results to CPU."""
        if self._data_fetcher is not None:
            self._data_fetcher.teardown()
            self._data_fetcher = None
        self._results.cpu()
        self.epoch_loop.teardown()

    def _get_max_batches(self) -> List[int]:
        """Returns the max number of batches for each dataloader."""
        if self.trainer.testing:
            max_batches = self.trainer.num_test_batches
        else:
            if self.trainer.sanity_checking:
                max_batches = self.trainer.num_sanity_val_batches
            else:
                max_batches = self.trainer.num_val_batches
        return max_batches

    def _reload_evaluation_dataloaders(self) -> None:
        """Reloads dataloaders if necessary."""
        if self.trainer.testing:
            self.trainer.reset_test_dataloader()
        elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl:
            self.trainer.reset_val_dataloader()

    def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_start`` hooks."""
        assert self._results is not None
        self._results.to(device=self.trainer.lightning_module.device)

        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_start", *args, **kwargs)
            self.trainer._call_strategy_hook("on_test_start", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_start", *args, **kwargs)
            self.trainer._call_strategy_hook("on_validation_start", *args, **kwargs)

    def _on_evaluation_model_eval(self) -> None:
        """Sets model to eval mode."""
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("on_test_model_eval")
        else:
            self.trainer._call_lightning_module_hook("on_validation_model_eval")

    def _on_evaluation_model_train(self) -> None:
        """Sets model to train mode."""
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("on_test_model_train")
        else:
            self.trainer._call_lightning_module_hook("on_validation_model_train")

    def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_end`` hook."""
        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_end", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_end", *args, **kwargs)
            self.trainer._call_strategy_hook("on_test_end", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_end", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_end", *args, **kwargs)
            self.trainer._call_strategy_hook("on_validation_end", *args, **kwargs)

        # reset the logger connector state
        self.trainer._logger_connector.reset_results()

    def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
        self.trainer._logger_connector.on_epoch_start()
        self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
        self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)

        if self.trainer.testing:
            self.trainer._call_callback_hooks("on_test_epoch_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_test_epoch_start", *args, **kwargs)
        else:
            self.trainer._call_callback_hooks("on_validation_epoch_start", *args, **kwargs)
            self.trainer._call_lightning_module_hook("on_validation_epoch_start", *args, **kwargs)

    def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
        """Runs ``{validation/test}_epoch_end``"""
        self.trainer._logger_connector._evaluation_epoch_end()

        # with a single dataloader don't pass a 2D list
        output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
            outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs
        )

        # call the model epoch end
        if self.trainer.testing:
            self.trainer._call_lightning_module_hook("test_epoch_end", output_or_outputs)
        else:
            self.trainer._call_lightning_module_hook("validation_epoch_end", output_or_outputs)

    def _on_evaluation_epoch_end(self) -> None:
        """Runs ``on_{validation/test}_epoch_end`` hook."""
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
        self.trainer._call_callback_hooks(hook_name)
        self.trainer._call_lightning_module_hook(hook_name)

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")
        self.trainer._logger_connector.on_epoch_end()

    @staticmethod
    def _get_keys(data: dict) -> Iterable[str]:
        """Yields the keys of ``data``, or the keys of its values if any value is itself a dictionary."""
        if any(isinstance(v, dict) for v in data.values()):
            for v in data.values():
                yield from apply_to_collection(v, dict, dict.keys)
        else:
            yield from data.keys()

    @staticmethod
    def _find_value(data: dict, target: str) -> Iterable[Any]:
        """Yields every value stored under the key ``target`` in ``data``, recursing into nested dictionaries."""
        for k, v in data.items():
            if k == target:
                yield v
            elif isinstance(v, dict):
                yield from EvaluationLoop._find_value(v, target)

    @staticmethod
    def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None:
        """Prints the logged metrics as one or more tables, one value column per dataloader.

        Args:
            results: The logged metrics, one dictionary per dataloader.
            stage: The name of the running stage; used to build the header of the metric column.
            file: The stream to print to. Defaults to ``sys.stdout``.
        """
        # print to stdout by default
        if file is None:
            file = sys.stdout

        # remove the dl idx suffix
        results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
        metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys})
        if not metrics:
            return
        headers = [f"DataLoader {i}" for i in range(len(results))]

        # fallback is useful for testing of printed output
        term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120
        # column width: at least 25 characters, at most half of the terminal width
        max_length = int(min(max(len(max(metrics + headers, key=len)), 25), term_size / 2))

        # one row per metric, one cell per dataloader
        rows: List[List[Any]] = [[] for _ in metrics]

        for result in results:
            for metric, row in zip(metrics, rows):
                v = list(EvaluationLoop._find_value(result, metric))
                if v:
                    val = v[0]
                    if isinstance(val, torch.Tensor):
                        val = val.item() if val.numel() == 1 else val.tolist()
                    row.append(f"{val}")
                else:
                    # this dataloader did not log the metric
                    row.append(" ")

        # keep one column with max length for metrics
        num_cols = int((term_size - max_length) / max_length)

        # emit one table per chunk of ``num_cols`` dataloaders
        for i in range(0, len(headers), num_cols):
            table_headers = headers[i : (i + num_cols)]
            table_rows = [row[i : (i + num_cols)] for row in rows]

            table_headers.insert(0, f"{stage} Metric".capitalize())

            if _RICH_AVAILABLE:
                console = Console(file=file)

                columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers]
                columns[0].style = "cyan"

                table = Table(*columns)
                for metric, row in zip(metrics, table_rows):
                    row.insert(0, metric)
                    table.add_row(*row)
                console.print(table)
            else:
                row_format = f"{{:^{max_length}}}" * len(table_headers)
                half_term_size = int(term_size / 2)

                # U+2500 BOX DRAWINGS LIGHT HORIZONTAL; written as an escape so that it survives re-encoding
                box_character = "\u2500"
                try:
                    # some terminals do not support this character
                    if hasattr(file, "encoding") and file.encoding is not None:
                        box_character.encode(file.encoding)
                except UnicodeEncodeError:
                    bar_character = "-"
                else:
                    bar_character = box_character
                bar = bar_character * term_size

                lines = [bar, row_format.format(*table_headers).rstrip(), bar]
                for metric, row in zip(metrics, table_rows):
                    # deal with column overflow
                    if len(metric) > half_term_size:
                        while len(metric) > half_term_size:
                            row_metric = metric[:half_term_size]
                            metric = metric[half_term_size:]
                            lines.append(row_format.format(row_metric, *row).rstrip())
                        lines.append(row_format.format(metric, " ").rstrip())
                    else:
                        lines.append(row_format.format(metric, *row).rstrip())
                lines.append(bar)
                print(os.linesep.join(lines), file=file)
def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
    """Pick the data fetcher class to use for evaluation.

    The choice is made in this order:

    1. ``DataLoaderIterDataFetcher`` if the ``test_step``/``validation_step`` of the trainer's module explicitly
       takes a ``dataloader_iter`` argument (a warning is emitted on rank zero).
    2. ``InterBatchParallelDataFetcher`` if the environment variable ``PL_INTER_BATCH_PARALLELISM`` is ``"1"``.
    3. ``DataFetcher`` otherwise.

    Raises:
        MisconfigurationException:
            If inter batch parallelism is requested but the trainer's accelerator is not a ``GPUAccelerator``.
    """
    module = trainer.lightning_module
    hook_name = "test_step" if trainer.testing else "validation_step"
    hook = getattr(module, hook_name)

    if is_param_in_hook_signature(hook, "dataloader_iter", explicit=True):
        rank_zero_warn(
            f"Found `dataloader_iter` argument in the `{hook_name}`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher

    # without the opt-in environment flag, the default fetcher is used
    if os.getenv("PL_INTER_BATCH_PARALLELISM", "0") != "1":
        return DataFetcher

    if not isinstance(trainer.accelerator, GPUAccelerator):
        raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
    return InterBatchParallelDataFetcher