From 8394770d4afa5480f881229b150ac44eaa8c41b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 9 Feb 2022 15:34:24 +0100 Subject: [PATCH] Move data fetcher ownership to the loops (#11621) --- CHANGELOG.md | 3 + .../loops/dataloader/evaluation_loop.py | 16 +++- pytorch_lightning/loops/fit_loop.py | 42 +++++++++- .../trainer/connectors/data_connector.py | 79 +------------------ pytorch_lightning/trainer/trainer.py | 1 - pytorch_lightning/utilities/fetching.py | 8 +- tests/utilities/test_fetching.py | 6 +- 7 files changed, 64 insertions(+), 91 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65bfeee76b..167bc7df4a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -151,6 +151,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved ownership of the lightning optimizers from the `Trainer` to the `Strategy` ([#11444](https://github.com/PyTorchLightning/pytorch-lightning/pull/11444)) +- Moved ownership of the data fetchers from the DataConnector to the Loops ([#11621](https://github.com/PyTorchLightning/pytorch-lightning/pull/11621)) + + - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649)) diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index d8ce9b605f..aab5356cba 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -14,6 +14,7 @@ import os import shutil from collections import OrderedDict +from functools import partial from typing import Any, IO, Iterable, List, Optional, Sequence, Union import torch @@ -25,6 +26,7 @@ from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -46,6 +48,7 @@ class EvaluationLoop(DataLoaderLoop): self._logged_outputs: List[_OUT_DICT] = [] self._max_batches: List[int] = [] self._has_run: bool = False + self._data_fetcher: Optional[AbstractDataFetcher] = None @property def num_dataloaders(self) -> int: @@ -107,6 +110,8 @@ class EvaluationLoop(DataLoaderLoop): hooks.""" void(*args, **kwargs) + self._data_fetcher = DataFetcher() + # hook self._on_evaluation_model_eval() self.trainer.lightning_module.zero_grad() @@ -119,15 +124,17 @@ class EvaluationLoop(DataLoaderLoop): dataloader_idx = self.current_dataloader_idx dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader) - self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader( - dataloader, dataloader_idx=dataloader_idx + assert self._data_fetcher is not None + self._data_fetcher.setup( + dataloader, + batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx), ) dl_max_batches = self._max_batches[dataloader_idx] kwargs = OrderedDict() if self.num_dataloaders > 1: kwargs["dataloader_idx"] = dataloader_idx - dl_outputs = self.epoch_loop.run(dataloader, dl_max_batches, kwargs) + dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs) # store batch 
level output per dataloader self._outputs.append(dl_outputs) @@ -177,6 +184,9 @@ class EvaluationLoop(DataLoaderLoop): return logged_outputs def teardown(self) -> None: + if self._data_fetcher is not None: + self._data_fetcher.teardown() + self._data_fetcher = None self._results.cpu() self.epoch_loop.teardown() diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 4954a2b74d..8bde81ea70 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -13,8 +13,12 @@ # limitations under the License. import logging import math -from typing import Optional +import os +from functools import partial +from typing import Optional, Type +import pytorch_lightning as pl +from pytorch_lightning.accelerators import GPUAccelerator from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE @@ -24,8 +28,15 @@ from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.fetching import ( + AbstractDataFetcher, + DataFetcher, + DataLoaderIterDataFetcher, + InterBatchParallelDataFetcher, +) from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature log = logging.getLogger(__name__) @@ -57,6 +68,7 @@ class FitLoop(Loop[None]): self._is_fresh_start_epoch: bool = True self._outputs: _EPOCH_OUTPUTS_TYPE = [] + self._data_fetcher: Optional[AbstractDataFetcher] = None @property def global_step(self) -> int: @@ -183,6 +195,8 @@ class FitLoop(Loop[None]): """Calls the ``on_train_start`` hook.""" # reset train dataloader and val dataloader self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module) + data_fetcher_cls = _select_data_fetcher(self.trainer) + self._data_fetcher = data_fetcher_cls() ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")): @@ -203,6 +217,7 @@ class FitLoop(Loop[None]): self._is_fresh_start_epoch = True self._results.to(device=self.trainer.lightning_module.device) + self.trainer._call_callback_hooks("on_train_start") self.trainer._call_lightning_module_hook("on_train_start") self.trainer._call_strategy_hook("on_train_start") @@ -250,10 +265,11 @@ class FitLoop(Loop[None]): log.detail(f"{self.__class__.__name__}: advancing loop") assert self.trainer.train_dataloader is not None dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader) - data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader, 0) - + self._data_fetcher.setup( + dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0) + ) with self.trainer.profiler.profile("run_training_epoch"): - self._outputs = self.epoch_loop.run(data_fetcher) + self._outputs = self.epoch_loop.run(self._data_fetcher) def on_advance_end(self) -> None: # inform logger the batch loop has finished @@ -324,8 +340,26 @@ class FitLoop(Loop[None]): self.trainer.strategy.on_train_end() def teardown(self) -> None: + if 
self._data_fetcher is not None: + self._data_fetcher.teardown() + self._data_fetcher = None self.epoch_loop.teardown() def _should_accumulate(self) -> bool: """Whether the gradients should be accumulated.""" return self.epoch_loop._should_accumulate() + + +def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]: + training_step_fx = getattr(trainer.lightning_module, "training_step") + if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): + rank_zero_warn( + "Found `dataloader_iter` argument in the `training_step`. Note that the support for " + "this signature is experimental and the behavior is subject to change." + ) + return DataLoaderIterDataFetcher + elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": + if not isinstance(trainer.accelerator, GPUAccelerator): + raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") + return InterBatchParallelDataFetcher + return DataFetcher diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 2d508de8d0..ef79bd88db 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,23 +14,18 @@ import multiprocessing import os from dataclasses import dataclass -from functools import partial -from typing import Any, Collection, Iterable, List, Optional, Tuple, Union +from typing import Any, Collection, List, Optional, Tuple, Union from weakref import proxy from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler import pytorch_lightning as pl -from pytorch_lightning.accelerators import GPUAccelerator from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.auto_restart import ( - _teardown_dataloader_get_iterators, - _validate_fault_tolerant_automatic, -) +from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _is_dataloader_shuffled, @@ -41,48 +36,22 @@ from pytorch_lightning.utilities.data import ( ) from pytorch_lightning.utilities.enums import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.fetching import ( - AbstractDataFetcher, - DataFetcher, - DataLoaderIterDataFetcher, - InterBatchParallelDataFetcher, -) from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from pytorch_lightning.utilities.warnings import PossibleUserWarning class DataConnector: - def __init__( - self, - trainer: "pl.Trainer", - multiple_trainloader_mode: str = "max_size_cycle", - train_data_fetcher: Optional[AbstractDataFetcher] = None, - validate_data_fetcher: Optional[AbstractDataFetcher] = None, - test_data_fetcher: Optional[AbstractDataFetcher] = None, - ): + 
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode - - self.train_data_fetcher = train_data_fetcher - self.validate_data_fetcher = validate_data_fetcher - self.test_data_fetcher = test_data_fetcher - self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None - self._train_dataloader_source = _DataLoaderSource(None, "") self._val_dataloader_source = _DataLoaderSource(None, "") self._test_dataloader_source = _DataLoaderSource(None, "") self._predict_dataloader_source = _DataLoaderSource(None, "") - @property - def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]: - if self.trainer.sanity_checking: - return self.sanity_check_data_fetcher - return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher - @property def _should_reload_train_dl(self) -> bool: """Check if train dataloader should be reloaded.""" @@ -126,33 +95,6 @@ class DataConnector: self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs self.trainer._is_data_prepared = False - def _select_data_fetcher(self) -> AbstractDataFetcher: - if not self.trainer.training: - return DataFetcher() - - training_step_fx = getattr(self.trainer.lightning_module, "training_step") - if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True): - rank_zero_warn( - "Found `dataloader_iter` argument in the `training_step`. Note that the support for " - "this signature is experimental and the behavior is subject to change." - ) - return DataLoaderIterDataFetcher() - elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1": - if not isinstance(self.trainer.accelerator, GPUAccelerator): - raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.") - return InterBatchParallelDataFetcher() - return DataFetcher() - - def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int) -> Iterable: - stage: str = self.trainer.state.stage.value - data_fetcher = getattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() - data_fetcher.setup( - dataloader, - batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx), - ) - setattr(self, f"{stage}_data_fetcher", data_fetcher) - return data_fetcher - def prepare_data(self) -> None: # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 @@ -559,21 +501,6 @@ class DataConnector: category=PossibleUserWarning, ) - def teardown(self) -> None: - if self.train_data_fetcher: - self.train_data_fetcher.teardown() - self.train_data_fetcher = None - if self.validate_data_fetcher: - self.validate_data_fetcher.teardown() - self.validate_data_fetcher = None - if self.test_data_fetcher: - self.test_data_fetcher.teardown() - self.test_data_fetcher = None - if self.sanity_check_data_fetcher: - self.sanity_check_data_fetcher.teardown() - self.sanity_check_data_fetcher = None - _teardown_dataloader_get_iterators() - @dataclass class _DataLoaderSource: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f0b56e35e1..91d319a113 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1248,7 +1248,6 @@ class Trainer( Callback; those are handled by :meth:`_call_teardown_hook`.""" 
self.strategy.post_dispatch(self) self.strategy.teardown() - self._data_connector.teardown() loop = self._active_loop # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn` if loop is not None: diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index e523bf7a35..0b8b127661 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -188,10 +188,10 @@ class AbstractDataFetcher(ABC): def teardown(self) -> None: self.reset() - if isinstance(self.dataloader, CombinedLoader): - self.dataloader.reset() - if isinstance(self.dataloader, DataLoader): - CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader) + if isinstance(self._dataloader, CombinedLoader): + self._dataloader.reset() + if isinstance(self._dataloader, DataLoader): + CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader) self.dataloader_iter = None _teardown_dataloader_get_iterators() diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index aa01c87041..4be1bed1e3 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -184,7 +184,7 @@ def test_trainer_num_prefetch_batches(tmpdir): self._check_inter_batch = check_inter_batch def on_train_epoch_end(self, trainer, lightning_module): - fetcher = trainer._data_connector.train_data_fetcher + fetcher = trainer.fit_loop._data_fetcher assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher) assert fetcher.prefetch_batches == 1 @@ -232,7 +232,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir): def training_step(self, dataloader_iter, batch_idx): assert self.count == batch_idx - assert isinstance(self.trainer._data_connector.train_data_fetcher, DataLoaderIterDataFetcher) + assert isinstance(self.trainer.fit_loop._data_fetcher, DataLoaderIterDataFetcher) # fetch 2 batches self.batches.append(next(dataloader_iter)) self.batches.append(next(dataloader_iter)) @@ -255,7 +255,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir): def training_epoch_end(self, *_): assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33 - assert self.trainer._data_connector.train_data_fetcher.fetched == 64 + assert self.trainer.fit_loop._data_fetcher.fetched == 64 assert self.count == 64 model = TestModel(automatic_optimization=automatic_optimization)
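
The net effect of this patch is that each loop owns the full lifecycle of its data fetcher: the FitLoop and EvaluationLoop create a fetcher when their run starts, call `setup()` with a `batch_to_device` partial as they advance over a dataloader, and release it in `teardown()`. The selection logic also moves from `DataConnector._select_data_fetcher` (which returned fetcher instances and cached them per stage) to a module-level `_select_data_fetcher` helper in `fit_loop.py` that returns the fetcher class for the loop to instantiate. The sketch below is a framework-free illustration of that ownership pattern, not Lightning API: `SimpleFetcher`, `SimpleLoop`, and the no-op `_batch_to_device` are hypothetical stand-ins.

    from functools import partial
    from typing import Callable, Iterable, Iterator, Optional


    class SimpleFetcher:
        """Hypothetical stand-in for ``AbstractDataFetcher``: wraps a dataloader
        and applies ``batch_to_device`` to each batch as it is fetched."""

        def __init__(self) -> None:
            self._dataloader: Optional[Iterable] = None
            self._batch_to_device: Optional[Callable] = None

        def setup(self, dataloader: Iterable, batch_to_device: Callable) -> None:
            self._dataloader = dataloader
            self._batch_to_device = batch_to_device

        def __iter__(self) -> Iterator:
            assert self._dataloader is not None and self._batch_to_device is not None
            for batch in self._dataloader:
                yield self._batch_to_device(batch)

        def teardown(self) -> None:
            # drop the dataloader reference (Lightning additionally shuts down workers here)
            self._dataloader = None
            self._batch_to_device = None


    class SimpleLoop:
        """Hypothetical loop that owns its fetcher, mirroring the lifecycle in this patch:
        create at run start, ``setup`` per dataloader while advancing, release in ``teardown``."""

        def __init__(self) -> None:
            self._data_fetcher: Optional[SimpleFetcher] = None

        def on_run_start(self) -> None:
            # the loop, not a connector, instantiates the fetcher
            self._data_fetcher = SimpleFetcher()

        def advance(self, dataloader: Iterable, dataloader_idx: int = 0) -> list:
            assert self._data_fetcher is not None
            self._data_fetcher.setup(
                dataloader,
                batch_to_device=partial(self._batch_to_device, dataloader_idx=dataloader_idx),
            )
            return [batch for batch in self._data_fetcher]

        def teardown(self) -> None:
            if self._data_fetcher is not None:
                self._data_fetcher.teardown()
                self._data_fetcher = None

        @staticmethod
        def _batch_to_device(batch, dataloader_idx: int):
            # placeholder for ``trainer._call_strategy_hook("batch_to_device", ...)``
            return batch


    if __name__ == "__main__":
        loop = SimpleLoop()
        loop.on_run_start()
        print(loop.advance(dataloader=[1, 2, 3]))  # -> [1, 2, 3]
        loop.teardown()

Because creation, setup, and teardown now live on the same object, the fetcher's lifetime is bounded by the loop run, which is what allows the `self._data_connector.teardown()` call in `trainer.py` (and the per-stage fetcher attributes on `DataConnector`) to be removed above.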