Move data fetcher ownership to the loops (#11621)
parent 24de29974c
commit 8394770d4a
@@ -151,6 +151,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Moved ownership of the lightning optimizers from the `Trainer` to the `Strategy` ([#11444](https://github.com/PyTorchLightning/pytorch-lightning/pull/11444))
+- Moved ownership of the data fetchers from the DataConnector to the Loops ([#11621](https://github.com/PyTorchLightning/pytorch-lightning/pull/11621))
 - Moved `batch_to_device` method from `Accelerator` to `TrainingTypePlugin` ([#10649](https://github.com/PyTorchLightning/pytorch-lightning/pull/10649))

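For downstream code, the practical effect of the ownership move is that the fetcher now lives on the loop rather than on the `DataConnector`. A minimal sketch of the new access pattern, mirroring the test updates at the bottom of this diff (the callback name is a placeholder):

```python
import pytorch_lightning as pl
from pytorch_lightning.utilities.fetching import DataFetcher


class FetcherInspector(pl.Callback):
    """Placeholder callback used only to illustrate the new attribute location."""

    def on_train_epoch_end(self, trainer, pl_module):
        # The fit loop now owns (creates, sets up, and tears down) the fetcher;
        # `trainer._data_connector.train_data_fetcher` no longer exists after this change.
        fetcher = trainer.fit_loop._data_fetcher
        assert isinstance(fetcher, DataFetcher)
```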
@@ -14,6 +14,7 @@
 import os
 import shutil
 from collections import OrderedDict
+from functools import partial
 from typing import Any, IO, Iterable, List, Optional, Sequence, Union

 import torch

@@ -25,6 +26,7 @@ from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
 from pytorch_lightning.utilities.imports import _RICH_AVAILABLE
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -46,6 +48,7 @@ class EvaluationLoop(DataLoaderLoop):
         self._logged_outputs: List[_OUT_DICT] = []
         self._max_batches: List[int] = []
         self._has_run: bool = False
+        self._data_fetcher: Optional[AbstractDataFetcher] = None

     @property
     def num_dataloaders(self) -> int:

@@ -107,6 +110,8 @@
         hooks."""
         void(*args, **kwargs)

+        self._data_fetcher = DataFetcher()
+
         # hook
         self._on_evaluation_model_eval()
         self.trainer.lightning_module.zero_grad()

@@ -119,15 +124,17 @@

         dataloader_idx = self.current_dataloader_idx
         dataloader = self.trainer.strategy.process_dataloader(self.current_dataloader)
-        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
-            dataloader, dataloader_idx=dataloader_idx
+        assert self._data_fetcher is not None
+        self._data_fetcher.setup(
+            dataloader,
+            batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx),
         )
         dl_max_batches = self._max_batches[dataloader_idx]

         kwargs = OrderedDict()
         if self.num_dataloaders > 1:
             kwargs["dataloader_idx"] = dataloader_idx
-        dl_outputs = self.epoch_loop.run(dataloader, dl_max_batches, kwargs)
+        dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)

         # store batch level output per dataloader
         self._outputs.append(dl_outputs)

@@ -177,6 +184,9 @@
         return logged_outputs

     def teardown(self) -> None:
+        if self._data_fetcher is not None:
+            self._data_fetcher.teardown()
+            self._data_fetcher = None
         self._results.cpu()
         self.epoch_loop.teardown()

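The `batch_to_device` argument that the loop passes to `setup` above is a `functools.partial` that pre-binds the strategy-hook name and the `dataloader_idx`, leaving only the batch for the fetcher to supply. A standalone sketch of that binding, with a hypothetical stand-in for `trainer._call_strategy_hook`:

```python
from functools import partial


# Hypothetical stand-in for `trainer._call_strategy_hook`, used only to show the binding.
def call_strategy_hook(hook_name, batch, dataloader_idx=0):
    print(f"{hook_name}(dataloader_idx={dataloader_idx}) -> moving batch to device")
    return batch


# The loop binds everything except the batch itself ...
batch_to_device = partial(call_strategy_hook, "batch_to_device", dataloader_idx=1)

# ... so the data fetcher can call it with just the batch it prefetched.
batch_to_device({"x": [1, 2, 3]})
```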
@@ -13,8 +13,12 @@
 # limitations under the License.
 import logging
 import math
-from typing import Optional
+import os
+from functools import partial
+from typing import Optional, Type

 import pytorch_lightning as pl
+from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
 from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE

@@ -24,8 +28,15 @@ from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.fetching import (
+    AbstractDataFetcher,
+    DataFetcher,
+    DataLoaderIterDataFetcher,
+    InterBatchParallelDataFetcher,
+)
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

 log = logging.getLogger(__name__)

@@ -57,6 +68,7 @@ class FitLoop(Loop[None]):

         self._is_fresh_start_epoch: bool = True
         self._outputs: _EPOCH_OUTPUTS_TYPE = []
+        self._data_fetcher: Optional[AbstractDataFetcher] = None

     @property
     def global_step(self) -> int:

@@ -183,6 +195,8 @@
         """Calls the ``on_train_start`` hook."""
         # reset train dataloader and val dataloader
         self.trainer.reset_train_val_dataloaders(self.trainer.lightning_module)
+        data_fetcher_cls = _select_data_fetcher(self.trainer)
+        self._data_fetcher = data_fetcher_cls()

         ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
         if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):

@@ -203,6 +217,7 @@

         self._is_fresh_start_epoch = True
         self._results.to(device=self.trainer.lightning_module.device)

         self.trainer._call_callback_hooks("on_train_start")
         self.trainer._call_lightning_module_hook("on_train_start")
         self.trainer._call_strategy_hook("on_train_start")

@@ -250,10 +265,11 @@
         log.detail(f"{self.__class__.__name__}: advancing loop")
         assert self.trainer.train_dataloader is not None
         dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
-        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader, 0)
-
+        self._data_fetcher.setup(
+            dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
+        )
         with self.trainer.profiler.profile("run_training_epoch"):
-            self._outputs = self.epoch_loop.run(data_fetcher)
+            self._outputs = self.epoch_loop.run(self._data_fetcher)

     def on_advance_end(self) -> None:
         # inform logger the batch loop has finished

@@ -324,8 +340,26 @@
         self.trainer.strategy.on_train_end()

     def teardown(self) -> None:
+        if self._data_fetcher is not None:
+            self._data_fetcher.teardown()
+            self._data_fetcher = None
         self.epoch_loop.teardown()

     def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
+
+
+def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
+    training_step_fx = getattr(trainer.lightning_module, "training_step")
+    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
+        rank_zero_warn(
+            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
+            "this signature is experimental and the behavior is subject to change."
+        )
+        return DataLoaderIterDataFetcher
+    elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
+        if not isinstance(trainer.accelerator, GPUAccelerator):
+            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
+        return InterBatchParallelDataFetcher
+    return DataFetcher

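The selector above keys off two opt-in paths that appear elsewhere in this diff: an explicit `dataloader_iter` parameter on `training_step` (experimental) and the `PL_INTER_BATCH_PARALLELISM=1` environment variable (GPU only); everything else falls back to `DataFetcher`. A hedged sketch of what each opt-in looks like on the user side (the model class is illustrative only):

```python
import os

from pytorch_lightning import LightningModule


class IteratorControlledModel(LightningModule):
    # An explicit `dataloader_iter` parameter makes `_select_data_fetcher` pick
    # `DataLoaderIterDataFetcher`, handing batch fetching over to the user code.
    def training_step(self, dataloader_iter, batch_idx):
        batch = next(dataloader_iter)
        ...


# Setting this before `Trainer.fit` opts into `InterBatchParallelDataFetcher`
# (a `MisconfigurationException` is raised unless a GPU accelerator is used).
os.environ["PL_INTER_BATCH_PARALLELISM"] = "1"
```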
@@ -14,23 +14,18 @@
 import multiprocessing
 import os
 from dataclasses import dataclass
-from functools import partial
-from typing import Any, Collection, Iterable, List, Optional, Tuple, Union
+from typing import Any, Collection, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

 import pytorch_lightning as pl
-from pytorch_lightning.accelerators import GPUAccelerator
 from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.auto_restart import (
-    _teardown_dataloader_get_iterators,
-    _validate_fault_tolerant_automatic,
-)
+from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
     _is_dataloader_shuffled,

@@ -41,48 +36,22 @@ from pytorch_lightning.utilities.data import (
 )
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.fetching import (
-    AbstractDataFetcher,
-    DataFetcher,
-    DataLoaderIterDataFetcher,
-    InterBatchParallelDataFetcher,
-)
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from pytorch_lightning.utilities.warnings import PossibleUserWarning


 class DataConnector:
-    def __init__(
-        self,
-        trainer: "pl.Trainer",
-        multiple_trainloader_mode: str = "max_size_cycle",
-        train_data_fetcher: Optional[AbstractDataFetcher] = None,
-        validate_data_fetcher: Optional[AbstractDataFetcher] = None,
-        test_data_fetcher: Optional[AbstractDataFetcher] = None,
-    ):
+    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
         self.trainer = trainer
         self.multiple_trainloader_mode = multiple_trainloader_mode

-        self.train_data_fetcher = train_data_fetcher
-        self.validate_data_fetcher = validate_data_fetcher
-        self.test_data_fetcher = test_data_fetcher
-        self.sanity_check_data_fetcher: Optional[AbstractDataFetcher] = None
-
         self._train_dataloader_source = _DataLoaderSource(None, "")
         self._val_dataloader_source = _DataLoaderSource(None, "")
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")

-    @property
-    def evaluation_data_fetcher(self) -> Optional[AbstractDataFetcher]:
-        if self.trainer.sanity_checking:
-            return self.sanity_check_data_fetcher
-        return self.test_data_fetcher if self.trainer.testing else self.validate_data_fetcher
-
     @property
     def _should_reload_train_dl(self) -> bool:
         """Check if train dataloader should be reloaded."""

@@ -126,33 +95,6 @@ class DataConnector:
         self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
         self.trainer._is_data_prepared = False

-    def _select_data_fetcher(self) -> AbstractDataFetcher:
-        if not self.trainer.training:
-            return DataFetcher()
-
-        training_step_fx = getattr(self.trainer.lightning_module, "training_step")
-        if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
-            rank_zero_warn(
-                "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
-                "this signature is experimental and the behavior is subject to change."
-            )
-            return DataLoaderIterDataFetcher()
-        elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
-            if not isinstance(self.trainer.accelerator, GPUAccelerator):
-                raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
-            return InterBatchParallelDataFetcher()
-        return DataFetcher()
-
-    def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int) -> Iterable:
-        stage: str = self.trainer.state.stage.value
-        data_fetcher = getattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher()
-        data_fetcher.setup(
-            dataloader,
-            batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx),
-        )
-        setattr(self, f"{stage}_data_fetcher", data_fetcher)
-        return data_fetcher
-
     def prepare_data(self) -> None:
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0

@@ -559,21 +501,6 @@ class DataConnector:
             category=PossibleUserWarning,
         )

-    def teardown(self) -> None:
-        if self.train_data_fetcher:
-            self.train_data_fetcher.teardown()
-            self.train_data_fetcher = None
-        if self.validate_data_fetcher:
-            self.validate_data_fetcher.teardown()
-            self.validate_data_fetcher = None
-        if self.test_data_fetcher:
-            self.test_data_fetcher.teardown()
-            self.test_data_fetcher = None
-        if self.sanity_check_data_fetcher:
-            self.sanity_check_data_fetcher.teardown()
-            self.sanity_check_data_fetcher = None
-        _teardown_dataloader_get_iterators()
-

 @dataclass
 class _DataLoaderSource:

|
@ -1248,7 +1248,6 @@ class Trainer(
|
|||
Callback; those are handled by :meth:`_call_teardown_hook`."""
|
||||
self.strategy.post_dispatch(self)
|
||||
self.strategy.teardown()
|
||||
self._data_connector.teardown()
|
||||
loop = self._active_loop
|
||||
# loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
|
||||
if loop is not None:
|
||||
|
|
|
@@ -188,10 +188,10 @@ class AbstractDataFetcher(ABC):

     def teardown(self) -> None:
         self.reset()
-        if isinstance(self.dataloader, CombinedLoader):
-            self.dataloader.reset()
-        if isinstance(self.dataloader, DataLoader):
-            CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
+        if isinstance(self._dataloader, CombinedLoader):
+            self._dataloader.reset()
+        if isinstance(self._dataloader, DataLoader):
+            CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader)
         self.dataloader_iter = None
         _teardown_dataloader_get_iterators()

@@ -184,7 +184,7 @@ def test_trainer_num_prefetch_batches(tmpdir):
             self._check_inter_batch = check_inter_batch

         def on_train_epoch_end(self, trainer, lightning_module):
-            fetcher = trainer._data_connector.train_data_fetcher
+            fetcher = trainer.fit_loop._data_fetcher
             assert isinstance(fetcher, InterBatchParallelDataFetcher if self._check_inter_batch else DataFetcher)
             assert fetcher.prefetch_batches == 1

@@ -232,7 +232,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):

         def training_step(self, dataloader_iter, batch_idx):
             assert self.count == batch_idx
-            assert isinstance(self.trainer._data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
+            assert isinstance(self.trainer.fit_loop._data_fetcher, DataLoaderIterDataFetcher)
             # fetch 2 batches
             self.batches.append(next(dataloader_iter))
             self.batches.append(next(dataloader_iter))

@@ -255,7 +255,7 @@ def test_fetching_dataloader_iter(automatic_optimization, tmpdir):

         def training_epoch_end(self, *_):
             assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
-            assert self.trainer._data_connector.train_data_fetcher.fetched == 64
+            assert self.trainer.fit_loop._data_fetcher.fetched == 64
             assert self.count == 64

     model = TestModel(automatic_optimization=automatic_optimization)