# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from copy import deepcopy
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    _patch_dataloader_get_iterators,
    _teardown_dataloader_get_iterators,
    IteratorState,
    MergedIteratorState,
    patch_dataloader_iterator,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class AbstractDataFetcher(ABC):
    """This base class should be used to implement a fault tolerant ``DataFetcher``. It is required to override the
    ``fetching_function`` with fetching logic.

    Example::

        class SimpleDataFetcher(AbstractDataFetcher):
            def fetching_function(self):
                while True:
                    try:
                        return next(self.dataloader_iter), False
                    except StopIteration:
                        return None, True
    """

    @abstractmethod
    def fetching_function(self) -> Any:
        """Override with your own fetching logic."""

    @abstractmethod
    def prefetching(self) -> None:
        """Override with your own pre-fetching logic."""

    def on_fetch_start(self) -> Any:
        """Hook to override to handle the logic before fetching a batch."""

    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
        """Hook to extend which handles the logic after fetching a batch."""

    def wait(self) -> None:
        """Hook to override to signal that the ``DataFetcher`` should wait for an event."""

    def __init__(self, prefetch_batches: int = 0) -> None:
        if prefetch_batches < 0:
            raise MisconfigurationException("`prefetch_batches` should at least be 0.")

        self.prefetch_batches = prefetch_batches

        self._dataloader: Optional[Iterable] = None
        self.dataloader_iter: Optional[Iterator] = None
        self.fetched: int = 0
        self.done: bool = False

    def setup(self, dataloader: Iterable, **kwargs: Any) -> None:
        self._add_capture_metadata_collate(dataloader)
        self._dataloader = dataloader

    @property
    def dataloader(self) -> Iterable:
        if self._dataloader is None:
            raise MisconfigurationException(
                f"`{self.__class__.__name__}` should have been `setup` with a dataloader iterable."
            )
        return self._dataloader

    @staticmethod
    def _add_capture_metadata_collate(dataloader: Iterable) -> None:
        if not isinstance(dataloader, (DataLoader, CombinedLoader)):
            return

        if isinstance(dataloader, CombinedLoader):
            dataloader = dataloader.loaders

        apply_to_collection(dataloader, DataLoader, _add_capture_metadata_collate)

    def _apply_patch(self) -> None:
        def _apply_patch_fn(loader: DataLoader, iterator: Iterator) -> None:
            if isinstance(loader, CycleIterator):
                loader = loader.loader
                # the iterator passed in is the `CycleIterator`; unwrap it to get
                # the underlying dataloader iterator
                iterator = iterator._loader_iter

            if isinstance(loader, DataLoader) and _fault_tolerant_training():
                loader._lightning_fetcher = self
                patch_dataloader_iterator(loader, iterator, self)

        apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)

    def _store_dataloader_iter_state(
        self, dataloader_iter: Iterator, dataloader_iter_states: List[IteratorState]
    ) -> None:
        if getattr(dataloader_iter, "cache_states", None) is None:
            dataloader_iter.cache_states = {}

        if getattr(dataloader_iter, "state", None) is None:
            dataloader_iter.state = MergedIteratorState()

        for iter_state in dataloader_iter_states:
            iter_name = iter_state.name
            if iter_name not in dataloader_iter.cache_states:
                dataloader_iter.cache_states[iter_name] = []
            dataloader_iter.cache_states[iter_name].append(iter_state)

        if self.fetched >= self.prefetch_batches:
            for iter_state in dataloader_iter_states:
                if len(dataloader_iter.state):
                    dataloader_iter.previous_state = deepcopy(dataloader_iter.state)
                iter_name = iter_state.name
                state = dataloader_iter.cache_states[iter_name].pop(0)
                dataloader_iter.state.update(iter_name, state)

    @property
    def loaders(self) -> List[DataLoader]:
        if isinstance(self.dataloader, CombinedLoader):
            loaders = self.dataloader.loaders
        else:
            loaders = [self.dataloader]
        return loaders

    @property
    def loader_iters(self) -> List[Iterator]:
        if self.dataloader_iter is None:
            raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")

        if isinstance(self.dataloader, CombinedLoader):
            loader_iters = self.dataloader_iter.loader_iters
        else:
            loader_iters = [self.dataloader_iter]
        return loader_iters

    @property
    def state(self) -> List[MergedIteratorState]:
        def collect_state(iterator: Iterator) -> MergedIteratorState:
            return iterator.state

        return apply_to_collection(self.loader_iters, Iterator, collect_state)

    def _attach_data_fetcher(self) -> None:
        def _attach_data_fetcher_fn(loader: DataLoader) -> None:
            if isinstance(loader, CycleIterator):
                loader = loader.loader

            if isinstance(loader, DataLoader) and _fault_tolerant_training():
                loader._lightning_fetcher = self

        apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)

    def __iter__(self) -> "AbstractDataFetcher":
        self.reset()
        self._attach_data_fetcher()
        _patch_dataloader_get_iterators()
        self.dataloader_iter = iter(self.dataloader)
        self._apply_patch()
        self.prefetching()
        return self

    def __next__(self) -> Any:
        return self.fetching_function()

    def reset(self) -> None:
        self.fetched = 0
        self.done = False

    def teardown(self) -> None:
        self.reset()
        if isinstance(self._dataloader, CombinedLoader):
            self._dataloader.reset()
        if isinstance(self._dataloader, DataLoader):
            CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader)
        self.dataloader_iter = None
        _teardown_dataloader_get_iterators()

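# Illustrative sketch (commented out; ``PassThroughFetcher`` and the ``range(4)``
# dataset are hypothetical, not part of this module): the minimal protocol a
# concrete fetcher must honour. ``setup`` binds the iterable, ``iter()`` builds
# ``dataloader_iter`` and runs ``prefetching``, and every ``next()`` call
# dispatches to ``fetching_function``.
#
#     class PassThroughFetcher(AbstractDataFetcher):
#         def prefetching(self) -> None:
#             pass  # this sketch does not fetch ahead
#
#         def fetching_function(self) -> Any:
#             try:
#                 batch = next(self.dataloader_iter)
#                 self.fetched += 1
#                 return batch
#             except StopIteration:
#                 self.done = True
#                 raise
#
#     fetcher = PassThroughFetcher()
#     fetcher.setup(DataLoader(range(4)))
#     batches = list(fetcher)  # [tensor([0]), tensor([1]), tensor([2]), tensor([3])]
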
def _no_op_batch_to_device(batch: Any) -> Any:
    return batch


class DataFetcher(AbstractDataFetcher):
    """This class is used to control batch fetching flow.

    Args:
        prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly
            track whether a batch is the last one (available with :attr:`self.done`).
        store_on_device: Whether to store the pre-fetched batches on device.
    """

    def __init__(self, prefetch_batches: int = 1, store_on_device: bool = True) -> None:
        super().__init__(prefetch_batches=prefetch_batches)
        self.store_on_device = store_on_device
        self.batch_to_device: Callable[[Any], Any] = _no_op_batch_to_device
        self.batches: List[Any] = []

    def setup(  # type: ignore[override]
        self, dataloader: Iterable, batch_to_device: Optional[Callable[[Any], Any]] = None
    ) -> None:
        super().setup(dataloader)
        if batch_to_device is not None:
            self.batch_to_device = batch_to_device

    def on_fetch_end(self, batch: Any, start_output: Any) -> None:
        """Hook to extend which handles the logic after fetching a batch."""
        self.batches.append(batch)

    def prefetching(self) -> None:
        iterator = self.dataloader_iter
        assert iterator is not None
        for _ in range(self.prefetch_batches):
            try:
                self._fetch_next_batch(iterator)
            except StopIteration:
                break

    def fetching_function(self) -> Any:
        assert self.dataloader_iter is not None
        if self.batches:
            # there are pre-fetched batches already from a previous `prefetching` call.
            # consume one
            batch = self.batches.pop(0)
            try:
                # refill the consumed batch
                self._fetch_next_batch(self.dataloader_iter)
            except StopIteration:
                # no more batches to fetch. we are done only if all pre-fetched batches were returned
                self.done = not self.batches
        elif not self.done:
            # this will run only when no pre-fetching was done.
            try:
                self._fetch_next_batch(self.dataloader_iter)
                # consume the batch we just fetched
                batch = self.batches.pop(0)
            except StopIteration as e:
                self.done = True
                raise e
        else:
            # the iterator is empty
            raise StopIteration
        self.wait()
        return self.move_to_device(batch)

    def _fetch_next_batch(self, iterator: Iterator) -> None:
        start_output = self.on_fetch_start()
        batch = next(iterator)
        self.fetched += 1
        self.on_fetch_end(batch, start_output)

    def move_to_device(self, batch: Any) -> Any:
        if self.store_on_device:
            batch = self.batch_to_device(batch)
        return batch

    def reset(self) -> None:
        super().reset()
        self.batches = []

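# Illustrative sketch (commented out; ``range(3)`` stands in for a real
# dataset): with the default ``prefetch_batches=1`` the fetcher always keeps
# one batch in ``self.batches``, so ``done`` turns True while the *last* batch
# is being yielded rather than one step too late.
#
#     fetcher = DataFetcher(prefetch_batches=1)
#     fetcher.setup(DataLoader(range(3)))
#     for batch in fetcher:
#         if fetcher.done:
#             ...  # ``batch`` is the final batch of the epoch
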
class InterBatchParallelDataFetcher(DataFetcher):
    """This class implements inter-batch parallelism, which aims at hiding the latency of host-to-device copy of
    input batches behind computationally intensive operations.

    .. code-block::

        Without parallelization:

            batch 0: [HtoD][forward][backward]
            batch 1:                          [HtoD][forward][backward]
            batch 2:                                                    [HtoD][forward][backward]

        With parallelization, the latency of HtoD copy can be hidden:

            batch 0: [HtoD][forward][backward]
            batch 1:       [HtoD]   [forward][backward]
            batch 2:             [HtoD]      [forward][backward]
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.cuda_stream = torch.cuda.Stream()
        self.events: List[torch.cuda.Event] = []

    def move_to_device(self, batch: Any) -> Any:
        with torch.cuda.stream(self.cuda_stream):
            return super().move_to_device(batch)

    def on_fetch_start(self) -> "torch.cuda.Event":
        # create a cuda event used to record the asynchronous copy of data to device.
        return torch.cuda.Event()

    def on_fetch_end(self, batch: Any, event: torch.cuda.Event) -> None:
        self.batches.append(batch)
        event.record()
        self.events.append(event)

    def wait(self) -> None:
        # pop the first event from the queue and wait for the batch to be available on device.
        event = self.events.pop(0)
        event.wait()

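# Illustrative sketch (commented out; requires a CUDA device, and ``cpu_batch``
# is a hypothetical pinned-memory tensor): the stream/event pattern the fetcher
# above relies on. A copy issued on a side stream runs concurrently with
# kernels on the default stream, and the recorded event lets the consumer
# block only when the batch is actually needed. ``non_blocking=True`` copies
# overlap only for pinned-memory host tensors.
#
#     copy_stream = torch.cuda.Stream()
#     with torch.cuda.stream(copy_stream):
#         gpu_batch = cpu_batch.to("cuda", non_blocking=True)
#         event = torch.cuda.Event()
#         event.record()  # recorded on ``copy_stream``
#     ...
#     event.wait()  # make the current stream wait until the copy has finished
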
class StepFuncDataLoaderIter(Iterator):
    """This class is a wrapper that keeps track of the dataloader iterator's fetching events, while fetching itself
    is left entirely to user control."""

    def __init__(self, iterator: Iterator, data_fetcher: AbstractDataFetcher) -> None:
        self.iterator = iterator
        self.data_fetcher = data_fetcher

    def __next__(self) -> Any:
        try:
            data = next(self.iterator)
            self.data_fetcher.fetched += 1
            return data
        except StopIteration as e:
            self.data_fetcher.done = True
            raise e


class DataLoaderIterDataFetcher(AbstractDataFetcher):
    """This class is used to return the ``dataloader_iter`` directly to the ``LightningModule``'s ``training_step``
    so that users can implement their own pre-fetching logic. This feature can be activated as follows:

    Example::

        class MyModel(LightningModule):
            def __init__(self):
                super().__init__()
                self.automatic_optimization = False

            def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
                # it is the user's responsibility to fetch and move the batch to the right device.
                batch = next(dataloader_iter)
                batch = batch.to(self.device)
                ...
    """

    def __init__(self) -> None:
        super().__init__()
        self.store_on_device = False

    def prefetching(self) -> None:
        iterator = self.dataloader_iter
        assert iterator is not None
        self.iterator = iter(StepFuncDataLoaderIter(iterator, self))

    def fetching_function(self) -> Tuple[int, Iterator]:
        if not self.done:
            return self.fetched, self.iterator
        raise StopIteration
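
# Illustrative sketch (commented out; ``MyModel`` and the two-batch loop are
# hypothetical, and ``batch`` is assumed to be a single tensor): because
# ``fetching_function`` hands the raw ``dataloader_iter`` to ``training_step``,
# a user can pull several batches per step, e.g. for manual gradient
# accumulation.
#
#     class MyModel(LightningModule):
#         def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
#             for _ in range(2):
#                 batch = next(dataloader_iter).to(self.device)
#                 ...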