[Feat] 2/n Add Fault Tolerant Training to LightningFetcher (#8891)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>

parent de22e40095 · commit 19136ac847
@@ -43,7 +43,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
* Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890))
* Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889))

* Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))
* Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891))

- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743))
@@ -32,7 +32,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
    _sampler_metadata_collate,
    _capture_metadata_collate,
    CaptureIterableDataset,
    FastForwardSampler,
)

@@ -529,5 +529,5 @@ class TrainerDataLoadingMixin(ABC):

        Wrap default collate function to retrieve ``FastForwardSampler`` state dict when fault tolerant is enabled.
        """
        dataloader.collate_fn = partial(
            _sampler_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
            _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
        )
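For orientation, here is a minimal sketch (not part of this diff) of the `functools.partial` wrapping pattern used in the hunk above; `capture_collate` and the tiny loader are hypothetical stand-ins for `_capture_metadata_collate` and a real training DataLoader:

    from functools import partial
    from torch.utils.data import DataLoader

    def capture_collate(samples, dataset, default_collate):
        # wrap the original collate and attach restart metadata next to the batch
        return {"data": default_collate(samples), "__pl_restart_meta": {"dataset_length": len(dataset)}}

    loader = DataLoader(range(10), batch_size=4)
    # bind the dataset and the original collate_fn, exactly like the hunk above
    loader.collate_fn = partial(capture_collate, dataset=loader.dataset, default_collate=loader.collate_fn)
    batch = next(iter(loader))
    assert set(batch) == {"data", "__pl_restart_meta"}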
@@ -14,7 +14,9 @@

from collections.abc import Mapping
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union
from dataclasses import dataclass, field
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union

from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
@@ -49,9 +51,8 @@ class FastForwardSampler(Sampler):

        return getattr(self._sampler, key, None)

    def setup(self, dataloader_batch_size: Optional[int] = None) -> None:
        """
        Setup the ``FastForwardSampler``.
        This is required only when the provided dataset subclassed :class:`torch.utils.data.Dataset`.
        """Setup the ``FastForwardSampler``. This is required only when the provided dataset subclassed
        :class:`torch.utils.data.Dataset`.
        """
        self._dataloader_batch_size = dataloader_batch_size

@@ -61,9 +62,10 @@ class FastForwardSampler(Sampler):

        return worker_info.id if worker_info else 0

    def __iter__(self) -> Iterator[Any]:
        # the `state dict` was cached as workers were unavailable before
        # reload it now
        self._load_cached_state()
        self._current_iteration = 0
        # the `state dict` was cached as workers were unavailable before.
        if self._cached_state_dict is not None:
            self._load_non_random_state(self._cached_state_dict)

        i = 0
        sampler_iter = iter(self._sampler)

@@ -72,6 +74,10 @@ class FastForwardSampler(Sampler):

            i += 1

        # here: i == self._current_iteration
        if self._cached_state_dict is not None:
            self._cached_state_dict = None

        # recreate iterator to be sure loading is reflected there as well
        while True:
            self._current_iteration += 1
            try:

@@ -80,6 +86,7 @@ class FastForwardSampler(Sampler):

                break

        self._current_iteration = 0
        self._cached_state_dict = None
        self.restarting = False

    def __len__(self) -> int:
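As a conceptual aside (not part of the diff), the fast-forward behaviour that the `__iter__` above implements boils down to: skip the indices that were already consumed before a failure, then keep sampling normally. The helper below is a hypothetical miniature of that idea, not the class itself:

    from torch.utils.data import SequentialSampler

    def fast_forward(sampler, current_iteration):
        it = iter(sampler)
        for _ in range(current_iteration):
            next(it)      # drop indices already consumed before the restart
        yield from it     # then continue where the previous run stopped

    assert list(fast_forward(SequentialSampler(range(6)), current_iteration=4)) == [4, 5]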
@@ -116,13 +123,111 @@ class FastForwardSampler(Sampler):

        return current_iteration

    def _load_cached_state(self):
        if self._cached_state_dict is None or self.worker_id not in self._cached_state_dict:
            return
        self._current_iteration = self._cached_state_dict[self.worker_id]["current_iteration"]
        # delete cached state, prevent reloading every time iter() is called
    def _load_non_random_state(self, state_dict: Dict[int, Dict[str, Any]]) -> None:
        self._current_iteration = state_dict[self.worker_id]["current_iteration"]


@dataclass(frozen=True, unsafe_hash=True)
class IteratorState:
    """The state of an iterator in a single worker process."""

    dataset_state: Dict[int, Any] = field(default_factory=dict)
    sampler_state: Dict[int, Any] = field(default_factory=dict)
    worker_id: int = 0
    num_workers: int = 0
    num_batches_fetched: int = 0
    name: Optional[str] = None

    @classmethod
    def from_state_dict(cls, state_dict) -> "IteratorState":
        return cls(**state_dict)

@dataclass
class MergedIteratorState:
    """This class is used to hold the current iterator state and lives on the iterator. It holds the current merged
    states from all worker processes. Once an iterator advances, it can store updates of the worker states in this
    merged iterator state."""

    state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
    latest_worker_id: int = 0
    represent_map_dataset: Optional[bool] = None

    def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
        # a map based dataset doesn't own a generator and therefore `generator_name` should be None.
        self.represent_map_dataset = generator_name is None
        if self.represent_map_dataset:
            state = self.state
        else:
            if generator_name not in self.state:
                self.state[generator_name] = {}
            state = self.state[generator_name]

        latest_worker_id = new_state.worker_id
        state[latest_worker_id] = new_state
        self.latest_worker_id = latest_worker_id

    @classmethod
    def from_state_dict(cls, state_dict) -> "MergedIteratorState":
        if state_dict["represent_map_dataset"]:
            state_dict["state"] = {
                worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
            }
        else:
            state_dict["state"] = {
                sampler_name: {
                    worker_id: IteratorState.from_state_dict(state) for worker_id, state in worker_state.items()
                }
                for sampler_name, worker_state in state_dict["state"].items()
            }
        return cls(**state_dict)

    def __len__(self) -> int:
        return len(self.state)

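To see how `IteratorState` and `MergedIteratorState` fit together, a small usage sketch (not part of the diff), assuming the two dataclasses exactly as added above and the map-style case where `generator_name` is `None`, so states are keyed directly by worker id:

    from dataclasses import asdict

    merged = MergedIteratorState()
    merged.update(None, IteratorState(sampler_state={0: {"current_iteration": 4}}, worker_id=0, num_batches_fetched=1))
    merged.update(None, IteratorState(sampler_state={1: {"current_iteration": 4}}, worker_id=1, num_batches_fetched=2))

    # serialise and rebuild: worker ids remain the keys of `state`
    restored = MergedIteratorState.from_state_dict(asdict(merged))
    assert restored.latest_worker_id == 1 and len(restored) == 2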
class CaptureMapDataset(Dataset):
    """This class is used to capture the state from the map-based state dataset."""

    def __init__(self, dataset: Dataset) -> None:
        self.dataset = dataset
        self._cached_state_dict = None

    @property
    def worker_id(self) -> int:
        worker_info = get_worker_info()
        return worker_info.id if worker_info else 0

    def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
        if self._cached_state_dict is not None:
            if self.worker_id in self._cached_state_dict:
                # TODO: reset random states
                pass
            self._cached_state_dict = None

        data = self.dataset[item]
        state_dict = self._state_dict()
        return data, state_dict

    def __len__(self) -> int:
        return len(self.dataset)

    def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
        # as workers aren't available, the ``state_dict`` is cached until workers are made available.
        state_dict = deepcopy(state_dict)

        if num_workers > 0:
            # remap states to worker ids starting at 0
            next_worker_id = latest_worker_id + 1
            old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
            state_dict = {
                new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict
            }
        self._cached_state_dict = state_dict

    def _state_dict(self) -> Dict[int, Dict[str, Any]]:
        return {self.worker_id: {"rng_states": {}}}

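A short usage sketch (not part of the diff) for the wrapper above; in the main process there is no worker info, so the returned state is keyed under worker id 0. The toy dataset is hypothetical:

    class ToyDataset(Dataset):
        def __len__(self):
            return 4

        def __getitem__(self, idx):
            return idx * 10

    wrapped = CaptureMapDataset(ToyDataset())
    data, state = wrapped[2]
    assert data == 20 and state == {0: {"rng_states": {}}}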
class CaptureIterableDataset(IterableDataset):
    """

@@ -136,8 +241,9 @@ class CaptureIterableDataset(IterableDataset):

    def __init__(self, dataset: IterableDataset) -> None:
        super().__init__()
        self.dataset = deepcopy(dataset)
        self._state_dict: Optional[Dict[int, Any]] = None
        self.samplers: Optional[Dict[str, FastForwardSampler]] = None
        self._state_dict: Optional[Dict[int, Any]] = None
        self._has_wrapped: bool = False

    @property
    def sampler(self) -> Sampler:
@@ -188,22 +294,29 @@ class CaptureIterableDataset(IterableDataset):

            # if `CaptureIterableDataset` was available, the sampler should reload its own state.
            if self._state_dict is not None:
                sampler.load_state_dict(self._state_dict[generator_attr_name])

            # store the samplers
            self.samplers[generator_attr_name] = sampler

            # replace generator with the generator from the `FastForwardSampler`.
            dataset_dict[generator_attr_name] = iter(sampler)

    def reset_on_epoch(self) -> None:
        self.reset_on_epoch()

    def reset_on_epoch(self):
        self._state_dict = None

    def __iter__(self) -> Iterator:
        # create a generator from the wrapped Iterative Dataset
        # if the dataset contained samplers, they will be transformers into generators
        # if the dataset contained samplers, they will be transformed into generators
        self.iter_data = iter(self.dataset)

        # wrap any generator associated to a Sampler into a `FastForwardSampler`.
        if isinstance(self.iter_data, Generator):
            raise MisconfigurationException(
                "PyTorch Lightning Fault-Tolerant feature does not support `__iter__` returning a generator."
                " Please use the `__next__` function to fetch the next batch and use a sampler for"
                " doing your iterations."
            )
        self._wrap_generator_samplers()
        return self
@@ -214,7 +327,6 @@ class CaptureIterableDataset(IterableDataset):

    def store_samplers_state_dict(iterator: Iterator, sampler_state_dict: List) -> None:
        """
        This function is used to store and update sampler state dict on its associated iterator.

        In Lightning, as the iterator is wrapped into a prefetching function,
        we needed to introduce a cache to delay updating the ``sampler_state_dict``.
        """
@@ -241,7 +353,7 @@ class CaptureIterableDataset(IterableDataset):

            {
                "batch": ...,  # data returned by DataLoader
                "__pl_samplers": {
                "__pl_restart_meta": {
                    "sampler0": {
                        0: {"current_iteration": ...},
                        1: {"current_iteration": ...},

@@ -251,14 +363,14 @@ class CaptureIterableDataset(IterableDataset):

        }

        Each sampler in the worker process tracks the current iteration. We return all of them to the main process
        as part of the sample and then a special collate function :func:`_sampler_metadata_collate`
        as part of the sample and then a special collate function :func:`_capture_metadata_collate`
        will extract the current iteration as part of the metadata returned by a custom batch.
        """

        def _sanitize(data: Mapping):
            out = []
            for k, v in data.items():
                if k == AutoRestartBatchKeys.PL_SAMPLERS:
                if k == AutoRestartBatchKeys.PL_RESTART_META:
                    state_dicts.append(v)
                    return data["data"]
                out.append((k, CaptureIterableDataset._sanitize_batch_from_sampler_state(v, state_dicts)))
@@ -376,20 +488,82 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]:

    return {"num_workers": num_workers, "previous_worker": previous_worker}


def _sampler_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict:
    """
    A collate function that adds the state dict of all samplers used in the worker processes.

def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict:
    """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or :class:`CaptureMapDataset`
    used in the worker processes. This function gets executed within the worker processes.
    The structure will be:

    .. code-block:: python

        {
            "data": ...,  # data returned by Dataset
            "__pl_samplers": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
            "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
        }
    """
    batch = default_collate(samples)
    if not isinstance(dataset, CaptureIterableDataset):
    if isinstance(dataset, CaptureIterableDataset):
        data = default_collate(samples)
        metadata = dataset.state_dict()

    elif isinstance(dataset, CaptureMapDataset):
        samples, states = zip(*samples)
        data = default_collate(samples)
        metadata = states[-1]
    else:
        return default_collate(samples)

    return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}

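To make the branchy return structure concrete, a sketch (not part of the diff) of what the map-style branch produces: each sample arrives from `CaptureMapDataset.__getitem__` as a `(data, state)` pair, the pairs are unzipped, and the last worker state is kept as metadata (`list` stands in for `default_collate` here):

    samples = [(0, {0: {"rng_states": {}}}), (1, {0: {"rng_states": {}}})]
    data, states = zip(*samples)
    batch = {"data": list(data), "__pl_restart_meta": states[-1]}
    assert batch == {"data": [0, 1], "__pl_restart_meta": {0: {"rng_states": {}}}}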
def patch_dataloader_iterator(
    dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0
) -> None:
    assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

    def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable:
        @wraps(fn)
        def wrapper():
            nonlocal num_batches_fetched
            nonlocal it
            nonlocal dl

            dataset = dl.dataset
            combined_batch = fn()

            batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]
            num_batches_fetched += 1

            if isinstance(dataset, CaptureIterableDataset):
                state = [
                    IteratorState(
                        num_workers=dataloader.num_workers,
                        sampler_state=iterator_state,
                        num_batches_fetched=num_batches_fetched,
                        worker_id=list(iterator_state.keys())[0],
                        name=sampler_iter_name,
                    )
                    for sampler_iter_name, iterator_state in state.items()
                ]
            elif isinstance(dataset, CaptureMapDataset):
                ff_sampler = _find_fast_forward_samplers(dl)
                state = [
                    IteratorState(
                        num_workers=dataloader.num_workers,
                        sampler_state=ff_sampler.state_dict(num_batches_fetched),
                        dataset_state=state,
                        worker_id=list(state.keys())[0],
                        num_batches_fetched=num_batches_fetched,
                    )
                ]
            prefetcher._store_dataloader_iter_state(it, state)
            return batch
            return {"data": batch, AutoRestartBatchKeys.PL_SAMPLERS: dataset.state_dict()}

        return wrapper

    iterator._next_data = _next_data_wrapper(iterator._next_data, iterator, dataloader, num_batches_fetched)


def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
    """Wrap default collate function to retrieve captured dataset state dict when fault tolerant is enabled."""
    dataloader.collate_fn = partial(
        _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
    )
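A compact sketch (not part of the diff) of the patching pattern used by `patch_dataloader_iterator`: replace an iterator's fetch method with a `functools.wraps` wrapper so each fetched batch also pushes its state to a side channel. `ToyIterator` and `record` are hypothetical:

    from functools import wraps

    class ToyIterator:
        def __init__(self, items):
            self._items = iter(items)

        def _next_data(self):
            return next(self._items)

    def patch(it, record):
        fn = it._next_data

        @wraps(fn)
        def wrapped():
            batch = fn()
            record.append(batch)  # plays the role of prefetcher._store_dataloader_iter_state
            return batch

        it._next_data = wrapped

    record = []
    it = ToyIterator([10, 20, 30])
    patch(it, record)
    assert it._next_data() == 10 and record == [10]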
@@ -121,8 +121,6 @@ class GradClipAlgorithmType(LightningEnum):


class AutoRestartBatchKeys(LightningEnum):
    """
    Defines special dictionary keys used to track sampler progress with multiple workers.
    """
    """Defines special dictionary keys used to track captured dataset state with multiple workers."""

    PL_SAMPLERS = "__pl_samplers"
    PL_RESTART_META = "__pl_restart_meta"
@@ -14,16 +14,25 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from copy import deepcopy
from functools import partial
from typing import Any, Generator, List, Optional, Tuple

from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    IteratorState,
    MergedIteratorState,
    patch_dataloader_iterator,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled


class AbstractDataFetcher(ABC):
class AbstractFetcher(ABC):

    """
    This class is used to control batch fetching flow.
@@ -48,7 +57,6 @@ class AbstractDataFetcher(ABC):

        self.batches: List
        self.fetched: int
        self.done: bool
        self.has_raised: bool

        self.reset()
@@ -58,36 +66,72 @@ class AbstractDataFetcher(ABC):

                "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
            )
        self.dataloader = dataloader
        if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial):
            _add_capture_metadata_collate(dataloader)

    def add_batch(self, batch: Any) -> None:
    def add_batch(self, batch) -> None:
        self.batches.append(batch)

    def fetch_batch(self) -> Any:
        return self.batches.pop(0)

    def _apply_patch(self):
        def _apply_patch_fn(loader: DataLoader, iterator: Iterator):
            if isinstance(loader, CycleIterator):
                loader = loader.loader
                # cycle_iterator = iterator
                iterator = iterator._loader_iter

            if isinstance(loader, DataLoader) and _fault_tolerant_enabled():
                loader._lightning_fetcher = self
                patch_dataloader_iterator(loader, iterator, self)

        apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn)

    def _store_dataloader_iter_state(
        self, dataloader_iter: Iterator, dataloader_iter_states: List[IteratorState]
    ) -> None:
        if getattr(dataloader_iter, "cache_states", None) is None:
            dataloader_iter.cache_states = {}

        if getattr(dataloader_iter, "state", None) is None:
            dataloader_iter.state = MergedIteratorState()

        for iter_state in dataloader_iter_states:
            iter_name = iter_state.name
            if iter_name not in dataloader_iter.cache_states:
                dataloader_iter.cache_states[iter_name] = []
            dataloader_iter.cache_states[iter_name].append(iter_state)

        if self.fetched >= self.prefetch_batches:
            for iter_state in dataloader_iter_states:
                if len(dataloader_iter.state):
                    dataloader_iter.previous_state = deepcopy(dataloader_iter.state)
                iter_name = iter_state.name
                state = dataloader_iter.cache_states[iter_name].pop(0)
                dataloader_iter.state.update(iter_name, state)

    @property
    def loaders(self) -> List[DataLoader]:
        if not self.dataloader:
        if self.dataloader is None:
            raise MisconfigurationException(
                "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
            )
        if isinstance(self.dataloader, CombinedLoader):
            loaders = self.dataloader.loaders
        elif isinstance(self.dataloader, (tuple, list)):
            loaders = self.dataloader
        else:
            loaders = [self.dataloader]
        return loaders

    @property
    def loader_iters(self) -> List[Iterator]:
        if not self.dataloader:
        if self.dataloader is None:
            raise MisconfigurationException(
                "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``."
            )

        if not self.dataloader_iter:
            raise MisconfigurationException("The dataloader_iter isn't available outside the __iter__ context.")
        if self.dataloader_iter is None:
            raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")

        if isinstance(self.dataloader, CombinedLoader):
            loader_iters = self.dataloader_iter.loader_iters
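The interplay between the per-iterator cache and the merged state above can be hard to follow, so here is a simplified sketch (not part of the diff) of the same idea with plain lists: states are buffered as batches are fetched and only promoted to the "current" state once enough batches have been prefetched, so the stored state never runs ahead of what the loop has actually consumed. The real code additionally keys the cache by iterator name and tracks `self.fetched` separately from the wrapper call count:

    prefetch_batches = 2
    cache, current = [], None

    for fetched, state in enumerate(["s1", "s2", "s3"], start=1):
        cache.append(state)              # always buffer the freshest worker state
        if fetched >= prefetch_batches:  # like `self.fetched >= self.prefetch_batches`
            current = cache.pop(0)       # promote the oldest buffered state

    assert current == "s2" and cache == ["s3"]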
@@ -107,15 +151,17 @@ class AbstractDataFetcher(ABC):

            raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
        self.reset()
        self.dataloader_iter = iter(self.dataloader)
        self._apply_patch()
        return self.fetching_function()

    def reset(self) -> None:
        self.batches: List = []
        self.dataloader: Optional[Iterable]
        self.fetched: int = 0
        self.done: bool = False


class LightningDataFetcher(AbstractDataFetcher):
class LightningDataFetcher(AbstractFetcher):

    """
    This class is used to control batch fetching flow.
@@ -30,9 +30,8 @@ from torch.utils.data.dataset import Dataset, IterableDataset

import tests.helpers.utils as tutils
from pytorch_lightning import Callback, seed_everything, Trainer
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    _dataloader_load_state_dict,
    _dataloader_to_state_dict,
    CaptureIterableDataset,
@@ -263,7 +262,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers):

    dataset = CaptureIterableDataset(dataset)

    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches = []
@@ -286,7 +285,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers):

    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)
    _add_capture_metadata_collate(dataloader)

    iter_dataloader = iter(dataloader)
    batches_restart = []
@@ -541,7 +540,7 @@ def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(ra

    )
    dataset = CaptureIterableDataset(dataset)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)
    _add_capture_metadata_collate(dataloader)

    epoch_results = []
    for _ in range(2):
@@ -564,8 +563,8 @@ def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(ra

        assert torch.equal(
            epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"]
        )
        assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"]  # worker id 0
        assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"]  # worker id 1
        assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 0
        assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"]  # worker id 1
        assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0])
    else:
        first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize)
@@ -602,7 +601,7 @@ def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(ra

    dataset = CaptureIterableDataset(dataset)
    dataset.load_state_dict(state_dict)
    dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator)
    Trainer._add_sampler_metadata_collate(dataloader)
    _add_capture_metadata_collate(dataloader)

    epoch_results_restart = []
    for _ in range(2):
@@ -661,130 +660,6 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w

    return dataset


def create_dataloader():
    dataset = range(50)
    num_workers = 2
    batch_size = 8
    sampler = FastForwardSampler(SequentialSampler(dataset))
    sampler.setup(batch_size)

    dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size)
    dataloader.fast_forward_sampler = sampler

    loader_dict = {
        "a": [DataLoader(create_iterable_dataset(3, num_workers), num_workers=num_workers, batch_size=3), dataloader],
        "b": DataLoader(
            create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2
        ),
    }
    apply_to_collection(loader_dict, DataLoader, Trainer._add_sampler_metadata_collate)
    return CombinedLoader(loader_dict)

# Lightning will wrap the iterator within a prefetch function as follows.
def prefetch_iterator(iterable: Iterable):
    it = iter(iterable)

    try:
        # the iterator may be empty from the beginning
        last = next(it)
    except StopIteration:
        return

    for val in it:
        # yield last and has next
        yield last, False, it
        last = val
    # yield last, no longer has next
    yield last, True, it

@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 15 sec and should be skipped in Azure CI")
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@RunIf(min_torch="1.7.0")
def test_combined_dataloader_state_dict_and_reload():
    """
    This test makes sure the CombinedLoader used in the condition of Lightning properly
    capture its children DataLoader states.
    """
    dataloader = create_dataloader()

    iter_dataloader = iter(prefetch_iterator(dataloader))
    num_batches_processed = 4
    for idx in range(1, num_batches_processed):
        _, _, prefetched_iterator = next(iter_dataloader)

        loader_iters = prefetched_iterator._loader_iters

        # when dealing with IterativeDataset,
        # the sampler state dict will be attached directly onto the iterator to simplify collection.

        if idx == 1:
            assert loader_iters["a"][0]._sampler_state_dict == [{"iter_sampler": {0: {"current_iteration": 3}}}]
            assert loader_iters["a"][1]._sampler_state_dict == []
            assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 2}}}]
        elif idx == 2:
            assert loader_iters["a"][0]._sampler_state_dict == [
                {"iter_sampler": {0: dict(current_iteration=3), 1: dict(current_iteration=3)}}
            ]
            assert loader_iters["a"][1]._sampler_state_dict == []
            assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 4}}}]
        else:
            assert loader_iters["a"][0]._sampler_state_dict == [
                {"iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=3)}}
            ]
            assert loader_iters["a"][1]._sampler_state_dict == []
            assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 6}}}]

    state_dict = dataloader.state_dict(num_batches_processed=3)

    expected = {
        "b": {"num_workers": 0, "previous_worker": None, "custom_sampler": {0: dict(current_iteration=6)}},
        "a": [
            {
                "num_workers": 2,
                "previous_worker": 1,
                "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=3)},
            },
            {"num_workers": 0, "previous_worker": None, 0: dict(current_iteration=24)},
        ],
    }
    assert state_dict == expected

    dataloader = create_dataloader()
    apply_to_collection(dataloader, DataLoader, Trainer._add_sampler_metadata_collate)
    dataloader.load_state_dict(state_dict)

    iter_dataloader = iter(prefetch_iterator(dataloader))
    _, _, prefetched_iterator = next(iter_dataloader)

    loader_iters = prefetched_iterator._loader_iters

    assert loader_iters["a"][0]._sampler_state_dict == [
        {"num_workers": 2, "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=6)}}
    ]
    assert loader_iters["a"][1]._sampler_state_dict == []
    assert loader_iters["b"]._sampler_state_dict == [
        {"num_workers": 0, "custom_sampler": {0: dict(current_iteration=8)}}
    ]

    state_dict = dataloader.state_dict(num_batches_processed=4)

    expected = {
        "a": [
            {
                "num_workers": 2,
                "previous_worker": 0,
                "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=6)},
            },
            {"num_workers": 0, "previous_worker": None, 0: dict(current_iteration=32)},
        ],
        "b": {"num_workers": 0, "previous_worker": None, "custom_sampler": {0: dict(current_iteration=8)}},
    }

    assert state_dict == expected


def test_dataloader_to_state_dict_and_reload():
    """
    Note: Those utilities are used only with DataLoader wrapping a ``mapping`` based dataset.
@@ -804,7 +679,7 @@ def test_dataloader_to_state_dict_and_reload():

    _ = next(iter_dataloader)

    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict == {"num_workers": 0, "previous_worker": None, 0: {"current_iteration": 16}}
    assert state_dict[0]["current_iteration"] == 16

    dataloader = create_dataloader()
    dataloader = _dataloader_load_state_dict(dataloader, state_dict)

@@ -812,7 +687,7 @@ def test_dataloader_to_state_dict_and_reload():

    _ = next(iter_dataloader)

    state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
    assert state_dict == {"num_workers": 0, "previous_worker": None, 0: {"current_iteration": 24}}
    assert state_dict[0]["current_iteration"] == 24


@RunIf(min_torch="1.7.0")
@@ -75,13 +75,14 @@ def test_misconfiguration_error():

    fetcher = LightningDataFetcher()
    with pytest.raises(
        MisconfigurationException, match="The `DataFetcher` should be setup with an instance of a PyTorch"
        MisconfigurationException,
        match="The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``.",
    ):
        fetcher.setup(range(10))

    fetcher = LightningDataFetcher()
    with pytest.raises(
        MisconfigurationException, match="The dataloader_iter isn't available outside the __iter__ context."
        MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context."
    ):
        loader = DataLoader(range(10))
        fetcher.setup(loader)