Move data fetcher ownership to the loops (#11621)

Carlos Mocholí · 2022-02-09 15:34:24 +01:00 · committed by GitHub
parent 24de29974c
commit 8394770d4a
7 changed files with 64 additions and 91 deletions

View File

@@ -151,6 +151,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved ownership of the lightning optimizers from the `Trainer` to the `Strategy` ([#11444](https://github.com/PyTorchLightning/pytorch-lightning/pull/11444))
- Moved ownership of the data fetchers from the DataConnector to the Loops ([#11621](https://github.com/PyTorchLightning/pytorch-lightning/pull/11621))
- Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))
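
The data-fetcher entry above captures the whole pattern of this commit: each loop creates its own data fetcher, binds it to the current dataloader, and tears it down when the loop finishes, instead of borrowing one from the `DataConnector`. Below is a minimal sketch of that lifecycle using hypothetical `ToyFetcher`/`ToyLoop` stand-ins, not the real Lightning classes:

```python
from typing import Iterable, Iterator, Optional


class ToyFetcher:
    """Hypothetical stand-in for AbstractDataFetcher: wraps a dataloader and yields batches."""

    def __init__(self) -> None:
        self.dataloader: Optional[Iterable] = None

    def setup(self, dataloader: Iterable) -> None:
        self.dataloader = dataloader

    def __iter__(self) -> Iterator:
        assert self.dataloader is not None
        yield from self.dataloader

    def teardown(self) -> None:
        # release whatever iterator/worker state the fetcher is holding
        self.dataloader = None


class ToyLoop:
    """Hypothetical stand-in for a loop that owns its fetcher end to end."""

    def __init__(self) -> None:
        self._data_fetcher: Optional[ToyFetcher] = None

    def on_run_start(self) -> None:
        # created by the loop itself, not handed over by a connector
        self._data_fetcher = ToyFetcher()

    def advance(self, dataloader: Iterable) -> None:
        assert self._data_fetcher is not None
        self._data_fetcher.setup(dataloader)  # (re)bound to the current dataloader
        for batch in self._data_fetcher:
            pass  # run the step on each batch

    def teardown(self) -> None:
        # the loop cleans up what it created
        if self._data_fetcher is not None:
            self._data_fetcher.teardown()
            self._data_fetcher = None


loop = ToyLoop()
loop.on_run_start()
loop.advance(dataloader=[{"x": 1}, {"x": 2}])
loop.teardown()
```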

View File

@@ -14,6 +14,7 @@
import os
import shutil
from collections import OrderedDict
from functools import partial
from typing import Any, IO, Iterable, List, Optional, Sequence, Union
import torch
@@ -25,6 +26,7 @@ from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -46,6 +48,7 @@ class EvaluationLoop(DataLoaderLoop):
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[int] = []
self._has_run: bool = False
self._data_fetcher: Optional[AbstractDataFetcher] = None
@property
def num_dataloaders(self) -> int:
@@ -107,6 +110,8 @@ class EvaluationLoop(DataLoaderLoop):
hooks."""
void(*args, **kwargs)
self._data_fetcher = DataFetcher()
# hook
self._on_evaluation_model_eval()
self.trainer.lightning_module.zero_grad()
@@ -119,15 +124,17 @@
dataloader_idx = self.current_dataloader_idx
dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader)
self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
dataloader, dataloader_idx=dataloader_idx
assert self._data_fetcher is not None
self._data_fetcher.setup(
dataloader,
batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx),
)
dl_max_batches = self._max_batches[dataloader_idx]
kwargs = OrderedDict()
if self.num_dataloaders > 1:
kwargs["dataloader_idx"] = dataloader_idx
dl_outputs = self.epoch_loop.run(dataloader, dl_max_batches, kwargs)
dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
# store batch level output per dataloader
self._outputs.append(dl_outputs)
@@ -177,6 +184,9 @@ class EvaluationLoop(DataLoaderLoop):
return logged_outputs
def teardown(self) -> None:
if self._data_fetcher is not None:
self._data_fetcher.teardown()
self._data_fetcher = None
self._results.cpu()
self.epoch_loop.teardown()
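
In the `advance` hunk above, the evaluation loop passes `batch_to_device` to the fetcher's `setup` as a `functools.partial`, so `dataloader_idx` is frozen in before the fetcher ever invokes the hook. A small self-contained sketch of that binding pattern; `move_to_device` is an illustrative stand-in, not the strategy's actual `batch_to_device` hook:

```python
from functools import partial


def move_to_device(batch, device: str = "cpu", dataloader_idx: int = 0):
    """Illustrative stand-in for the strategy's batch_to_device hook."""
    print(f"moving batch {batch!r} to {device} (dataloader {dataloader_idx})")
    return batch


# Freeze the keyword argument up front, mirroring
# partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx)
batch_to_device = partial(move_to_device, dataloader_idx=1)

for batch in ["a", "b"]:  # pretend these come from the data fetcher
    batch = batch_to_device(batch)
```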

View File

@@ -13,8 +13,12 @@
# limitations under the License.
import logging
import math
from typing import Optional
import os
from functools import partial
from typing import Optional, Type
import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
@@ -24,8 +28,15 @@ from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
DataFetcher,
DataLoaderIterDataFetcher,
InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
log = logging.getLogger(__name__)
@@ -57,6 +68,7 @@ class FitLoop(Loop[None]):
self._is_fresh_start_epoch: bool = True
self._outputs: _EPOCH_OUTPUTS_TYPE = []
self._data_fetcher: Optional[AbstractDataFetcher] = None
@property
def global_step(self) -> int:
@@ -183,6 +195,8 @@
"""Calls the ``on_train_start`` hook."""
# reset train dataloader and val dataloader
self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
data_fetcher_cls = _select_data_fetcher(self.trainer)
self._data_fetcher = data_fetcher_cls()
ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
@@ -203,6 +217,7 @@
self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)
self.trainer._call_callback_hooks("on_train_start")
self.trainer._call_lightning_module_hook("on_train_start")
self.trainer._call_strategy_hook("on_train_start")
@@ -250,10 +265,11 @@
log.detail(f"{self.__class__.__name__}: advancing loop")
assert self.trainer.train_dataloader is not None
dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader, 0)
self._data_fetcher.setup(
dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
)
with self.trainer.profiler.profile("run_training_epoch"):
self._outputs = self.epoch_loop.run(data_fetcher)
self._outputs = self.epoch_loop.run(self._data_fetcher)
def on_advance_end(self) -> None:
# inform logger the batch loop has finished
@@ -324,8 +340,26 @@
self.trainer.strategy.on_train_end()
def teardown(self) -> None:
if self._data_fetcher is not None:
self._data_fetcher.teardown()
self._data_fetcher = None
self.epoch_loop.teardown()
def _should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated."""
return self.epoch_loop._should_accumulate()
def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
training_step_fx = getattr(trainer.lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher
elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
if not isinstance(trainer.accelerator, GPUAccelerator):
raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
return InterBatchParallelDataFetcher
return DataFetcher
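
`_select_data_fetcher` now lives next to the `FitLoop` and returns a fetcher class, which the loop instantiates in `on_run_start`. The selection itself is a signature check plus an environment-variable check; the sketch below reproduces the same decision logic with plain `inspect` and dummy fetcher classes instead of Lightning's `is_param_in_hook_signature` helper and real fetchers:

```python
import inspect
import os


class DataFetcher: ...
class DataLoaderIterDataFetcher(DataFetcher): ...
class InterBatchParallelDataFetcher(DataFetcher): ...


def select_data_fetcher(training_step, on_gpu: bool) -> type:
    """Simplified mirror of the selection rules in the hunk above (dummy classes, not Lightning's)."""
    if "dataloader_iter" in inspect.signature(training_step).parameters:
        # experimental signature: the training_step drives the iterator itself
        return DataLoaderIterDataFetcher
    if os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        if not on_gpu:
            raise RuntimeError("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher
    return DataFetcher


def plain_step(self, batch, batch_idx): ...
def iter_step(self, dataloader_iter, batch_idx): ...


assert select_data_fetcher(plain_step, on_gpu=False) is DataFetcher
assert select_data_fetcher(iter_step, on_gpu=False) is DataLoaderIterDataFetcher
```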

View File

@@ -14,23 +14,18 @@
import multiprocessing
import os
from dataclasses import dataclass
from functools import partial
from typing import Any, Collection, Iterable, List, Optional, Tuple, Union
from typing import Any, Collection, List, Optional, Tuple, Union
from weakref import proxy
from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.accelerators import GPUAccelerator
from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
_teardown_dataloader_get_iterators,
_validate_fault_tolerant_automatic,
)
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
@@ -41,48 +36,22 @@ from pytorch_lightning.utilities.data import (
)
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
DataFetcher,
DataLoaderIterDataFetcher,
InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from pytorch_lightning.utilities.warnings import PossibleUserWarning
class DataConnector:
def __init__(
self,
trainer: "pl.Trainer",
multiple_trainloader_mode: str = "max_size_cycle",
train_data_fetcher: Optional[AbstractDataFetcher] = None,
validate_data_fetcher: Optional[AbstractDataFetcher] = None,
test_data_fetcher: Optional[AbstractDataFetcher] = None,
):
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
self.train_data_fetcher = train_data_fetcher
self.validate_data_fetcher = validate_data_fetcher
self.test_data_fetcher = test_data_fetcher
self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None
self._train_dataloader_source = _DataLoaderSource(None, "")
self._val_dataloader_source = _DataLoaderSource(None, "")
self._test_dataloader_source = _DataLoaderSource(None, "")
self._predict_dataloader_source = _DataLoaderSource(None, "")
@property
def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
if self.trainer.sanity_checking:
return self.sanity_check_data_fetcher
return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher
@property
def _should_reload_train_dl(self) -> bool:
"""Check if train dataloader should be reloaded."""
@@ -126,33 +95,6 @@ class DataConnector:
self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False
def _select_data_fetcher(self) -> AbstractDataFetcher:
if not self.trainer.training:
return DataFetcher()
training_step_fx = getattr(self.trainer.lightning_module, "training_step")
if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
rank_zero_warn(
"Found `dataloader_iter` argument in the `training_step`. Note that the support for "
"this signature is experimental and the behavior is subject to change."
)
return DataLoaderIterDataFetcher()
elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
if not isinstance(self.trainer.accelerator, GPUAccelerator):
raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
return InterBatchParallelDataFetcher()
return DataFetcher()
def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int) -> Iterable:
stage: str = self.trainer.state.stage.value
data_fetcher = getattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()
data_fetcher.setup(
dataloader,
batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx),
)
setattr(self, f"{stage}_data_fetcher", data_fetcher)
return data_fetcher
def prepare_data(self) -> None:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
@@ -559,21 +501,6 @@ class DataConnector:
category=PossibleUserWarning,
)
def teardown(self) -> None:
if self.train_data_fetcher:
self.train_data_fetcher.teardown()
self.train_data_fetcher = None
if self.validate_data_fetcher:
self.validate_data_fetcher.teardown()
self.validate_data_fetcher = None
if self.test_data_fetcher:
self.test_data_fetcher.teardown()
self.test_data_fetcher = None
if self.sanity_check_data_fetcher:
self.sanity_check_data_fetcher.teardown()
self.sanity_check_data_fetcher = None
_teardown_dataloader_get_iterators()
@dataclass
class _DataLoaderSource:

View File

@@ -1248,7 +1248,6 @@ class Trainer(
Callback; those are handled by :meth:`_call_teardown_hook`."""
self.strategy.post_dispatch(self)
self.strategy.teardown()
self._data_connector.teardown()
loop = self._active_loop
# loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
if loop is not None:

View File

@@ -188,10 +188,10 @@ class AbstractDataFetcher(ABC):
def teardown(self) -> None:
self.reset()
if isinstance(self.dataloader, CombinedLoader):
self.dataloader.reset()
if isinstance(self.dataloader, DataLoader):
CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
if isinstance(self._dataloader, CombinedLoader):
self._dataloader.reset()
if isinstance(self._dataloader, DataLoader):
CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader)
self.dataloader_iter = None
_teardown_dataloader_get_iterators()

View File

@ -184,7 +184,7 @@ def test_trainer_num_prefetch_batches(tmpdir):
self._check_inter_batch = check_inter_batch
def on_train_epoch_end(self, trainer, lightning_module):
fetcher = trainer._data_connector.train_data_fetcher
fetcher = trainer.fit_loop._data_fetcher
assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher)
assert fetcher.prefetch_batches == 1
@@ -232,7 +232,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
def training_step(self, dataloader_iter, batch_idx):
assert self.count == batch_idx
assert isinstance(self.trainer._data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
assert isinstance(self.trainer.fit_loop._data_fetcher, DataLoaderIterDataFetcher)
# fetch 2 batches
self.batches.append(next(dataloader_iter))
self.batches.append(next(dataloader_iter))
@@ -255,7 +255,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):
def training_epoch_end(self, *_):
assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
assert self.trainer._data_connector.train_data_fetcher.fetched == 64
assert self.trainer.fit_loop._data_fetcher.fetched == 64
assert self.count == 64
model = TestModel(automatic_optimization=automatic_optimization)