diff --git a/CHANGELOG.md b/CHANGELOG.md index e0afbb8bad..664477da03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) * Added `LightningDataFetcher` to control fetching flow ([#8890](https://github.com/PyTorchLightning/pytorch-lightning/pull/8890)) * Added `SharedCycleIteratorState` to prevent infinite loop ([#8889](https://github.com/PyTorchLightning/pytorch-lightning/pull/8889)) - + * Added `CaptureMapDataset` for state management in map-style datasets ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) + * Added Fault Tolerant Training to LightningFetcher ([#8891](https://github.com/PyTorchLightning/pytorch-lightning/pull/8891)) - Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 10fd5bb390..e63138a74d 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -32,7 +32,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( - _sampler_metadata_collate, + _capture_metadata_collate, CaptureIterableDataset, FastForwardSampler, ) @@ -529,5 +529,5 @@ class TrainerDataLoadingMixin(ABC): Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is enabled. """ dataloader.collate_fn = partial( - _sampler_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn + _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn ) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 464823038a..c60c42bdc9 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -14,7 +14,9 @@ from collections.abc import Mapping from copy import deepcopy -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Union +from dataclasses import dataclass, field +from functools import partial, wraps +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset @@ -49,9 +51,8 @@ class FastForwardSampler(Sampler): return getattr(self._sampler, key, None) def setup(self, dataloader_batch_size: Optional[int] = None) -> None: - """ - Setup the ``FastForwardSampler``. - This is required only when the provided dataset subclassed :class:`torch.utils.data.Dataset`. + """Setup the ``FastForwardSampler``. This is required only when the provided dataset subclassed + :class:`torch.utils.data.Dataset`. 
""" self._dataloader_batch_size = dataloader_batch_size @@ -61,9 +62,10 @@ class FastForwardSampler(Sampler): return worker_info.id if worker_info else 0 def __iter__(self) -> Iterator[Any]: - # the `state dict` was cached as workers were unavailable before - # reload it now - self._load_cached_state() + self._current_iteration = 0 + # the `state dict` was cached as workers were unavailable before. + if self._cached_state_dict is not None: + self._load_non_random_state(self._cached_state_dict) i = 0 sampler_iter = iter(self._sampler) @@ -72,6 +74,10 @@ class FastForwardSampler(Sampler): i += 1 # here: i == self._current_iteration + if self._cached_state_dict is not None: + self._cached_state_dict = None + + # recreate iterator to be sure loading is reflected there as well while True: self._current_iteration += 1 try: @@ -80,6 +86,7 @@ class FastForwardSampler(Sampler): break self._current_iteration = 0 + self._cached_state_dict = None self.restarting = False def __len__(self) -> int: @@ -116,13 +123,111 @@ class FastForwardSampler(Sampler): return current_iteration - def _load_cached_state(self): - if self._cached_state_dict is None or self.worker_id not in self._cached_state_dict: - return - self._current_iteration = self._cached_state_dict[self.worker_id]["current_iteration"] - # delete cached state, prevent reloading every time iter() is called + def _load_non_random_state(self, state_dict: Dict[int, Dict[str, Any]]) -> None: + self._current_iteration = state_dict[self.worker_id]["current_iteration"] + + +@dataclass(frozen=True, unsafe_hash=True) +class IteratorState: + """The state of an iterator in a single worker process.""" + + dataset_state: Dict[int, Any] = field(default_factory=dict) + sampler_state: Dict[int, Any] = field(default_factory=dict) + worker_id: int = 0 + num_workers: int = 0 + num_batches_fetched: int = 0 + name: Optional[str] = None + + @classmethod + def from_state_dict(cls, state_dict) -> "IteratorState": + return cls(**state_dict) + + +@dataclass +class MergedIteratorState: + """This class is used to hold the current iterator state and lives on the iterator. It holds the current merged + states from all worker processes. Once an iterator advances, it can store updates of the worker states in this + merged iterator state.""" + + state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict) + latest_worker_id: int = 0 + represent_map_dataset: Optional[bool] = None + + def update(self, generator_name: Optional[str], new_state: IteratorState) -> None: + # a map based dataset doesn't own a generator and therefore `generator_name` should be None. 
+ self.represent_map_dataset = generator_name is None + if self.represent_map_dataset: + state = self.state + else: + if generator_name not in self.state: + self.state[generator_name] = {} + state = self.state[generator_name] + + latest_worker_id = new_state.worker_id + state[latest_worker_id] = new_state + self.latest_worker_id = latest_worker_id + + @classmethod + def from_state_dict(cls, state_dict) -> "MergedIteratorState": + if state_dict["represent_map_dataset"]: + state_dict["state"] = { + worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items() + } + else: + state_dict["state"] = { + sampler_name: { + worker_id: IteratorState.from_state_dict(state) for worker_id, state in worker_state.items() + } + for sampler_name, worker_state in state_dict["state"].items() + } + return cls(**state_dict) + + def __len__(self) -> int: + return len(self.state) + + +class CaptureMapDataset(Dataset): + """This class is used to capture the state from the map-based state dataset.""" + + def __init__(self, dataset: Dataset) -> None: + self.dataset = dataset self._cached_state_dict = None + @property + def worker_id(self) -> int: + worker_info = get_worker_info() + return worker_info.id if worker_info else 0 + + def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]: + if self._cached_state_dict is not None: + if self.worker_id in self._cached_state_dict: + # TODO: reset random states + pass + self._cached_state_dict = None + + data = self.dataset[item] + state_dict = self._state_dict() + return data, state_dict + + def __len__(self) -> int: + return len(self.dataset) + + def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None: + # as workers aren't available, the ``state_dict``` is cached until workers are made available. + state_dict = deepcopy(state_dict) + + if num_workers > 0: + # remap states to worker ids starting at 0 + next_worker_id = latest_worker_id + 1 + old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] + state_dict = { + new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict + } + self._cached_state_dict = state_dict + + def _state_dict(self) -> Dict[int, Dict[str, Any]]: + return {self.worker_id: {"rng_states": {}}} + class CaptureIterableDataset(IterableDataset): """ @@ -136,8 +241,9 @@ class CaptureIterableDataset(IterableDataset): def __init__(self, dataset: IterableDataset) -> None: super().__init__() self.dataset = deepcopy(dataset) - self._state_dict: Optional[Dict[int, Any]] = None self.samplers: Optional[Dict[str, FastForwardSampler]] = None + self._state_dict: Optional[Dict[int, Any]] = None + self._has_wrapped: bool = False @property def sampler(self) -> Sampler: @@ -188,22 +294,29 @@ class CaptureIterableDataset(IterableDataset): # if `CaptureIterableDataset` was available, the sampler should reload its own state. if self._state_dict is not None: sampler.load_state_dict(self._state_dict[generator_attr_name]) - # store the samplers self.samplers[generator_attr_name] = sampler # replace generator with the generator from the `FastForwardSampler`. 
dataset_dict[generator_attr_name] = iter(sampler) - def reset_on_epoch(self) -> None: + self.reset_on_epoch() + + def reset_on_epoch(self): self._state_dict = None def __iter__(self) -> Iterator: # create a generator from the wrapped Iterative Dataset - # if the dataset contained samplers, they will be transformers into generators + # if the dataset contained samplers, they will be transformed into generators self.iter_data = iter(self.dataset) # wrap any generator associated to a Sampler into a `FastForwardSampler`. + if isinstance(self.iter_data, Generator): + raise MisconfigurationException( + "PyTorch Lightning Fault-Tolerant feature does not support `__iter__` returning a generator." + " Please use the `__next__` function to fetch the next batch and use a sampler for" + " doing your iterations." + ) self._wrap_generator_samplers() return self @@ -214,7 +327,6 @@ class CaptureIterableDataset(IterableDataset): def store_samplers_state_dict(iterator: Iterator, sampler_state_dict: List) -> None: """ This function is used to store and update sampler state dict on its associated iterator. - In Lightning, as the iterator is wrapped into a prefetching function, we needed to introduce a cache to delay updating the ``sampler_state_dict``. """ @@ -241,7 +353,7 @@ class CaptureIterableDataset(IterableDataset): { "batch": ..., # data returned by DataLoader - "__pl_samplers": { + "__pl_restart_meta": { "sampler0": { 0: {"current_iteration": ...}, 1: {"current_iteration": ...}, @@ -251,14 +363,14 @@ class CaptureIterableDataset(IterableDataset): } Each sampler in the worker process tracks the current iteration. We return all of them to the main process - as part of the sample and then a special collate function :func:`_sampler_metadata_collate` + as part of the sample and then a special collate function :func:`_capture_metadata_collate` will extract the current iteration as part of the metadata returned by a custom batch. """ def _sanitize(data: Mapping): out = [] for k, v in data.items(): - if k == AutoRestartBatchKeys.PL_SAMPLERS: + if k == AutoRestartBatchKeys.PL_RESTART_META: state_dicts.append(v) return data["data"] out.append((k, CaptureIterableDataset._sanitize_batch_from_sampler_state(v, state_dicts))) @@ -376,20 +488,82 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: return {"num_workers": num_workers, "previous_worker": previous_worker} -def _sampler_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: - """ - A collate function that adds the state dict of all samplers used in the worker processes. - +def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: + """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or :class:`CaptureMapDataset` + used in the worker processes. This function gets executed within the worker processes. The structure will be: .. 
code-block:: python { "data": ..., # data returned by Dataset - "__pl_samplers": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, + "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, } """ - batch = default_collate(samples) - if not isinstance(dataset, CaptureIterableDataset): - return batch - return {"data": batch, AutoRestartBatchKeys.PL_SAMPLERS: dataset.state_dict()} + if isinstance(dataset, CaptureIterableDataset): + data = default_collate(samples) + metadata = dataset.state_dict() + + elif isinstance(dataset, CaptureMapDataset): + samples, states = zip(*samples) + data = default_collate(samples) + metadata = states[-1] + else: + return default_collate(samples) + + return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata} + + +def patch_dataloader_iterator( + dataloader: DataLoader, iterator: Iterator, prefetcher, num_batches_fetched: int = 0 +) -> None: + assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)) + + def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable: + @wraps(fn) + def wrapper(): + nonlocal num_batches_fetched + nonlocal it + nonlocal dl + + dataset = dl.dataset + combined_batch = fn() + + batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META] + num_batches_fetched += 1 + + if isinstance(dataset, CaptureIterableDataset): + state = [ + IteratorState( + num_workers=dataloader.num_workers, + sampler_state=iterator_state, + num_batches_fetched=num_batches_fetched, + worker_id=list(iterator_state.keys())[0], + name=sampler_iter_name, + ) + for sampler_iter_name, iterator_state in state.items() + ] + elif isinstance(dataset, CaptureMapDataset): + ff_sampler = _find_fast_forward_samplers(dl) + state = [ + IteratorState( + num_workers=dataloader.num_workers, + sampler_state=ff_sampler.state_dict(num_batches_fetched), + dataset_state=state, + worker_id=list(state.keys())[0], + num_batches_fetched=num_batches_fetched, + ) + ] + prefetcher._store_dataloader_iter_state(it, state) + return batch + + return wrapper + + iterator._next_data = _next_data_wrapper(iterator._next_data, iterator, dataloader, num_batches_fetched) + + +def _add_capture_metadata_collate(dataloader: DataLoader) -> None: + """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled.""" + dataloader.collate_fn = partial( + _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn + ) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 11b7c9b1e3..977b763299 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -121,8 +121,6 @@ class GradClipAlgorithmType(LightningEnum): class AutoRestartBatchKeys(LightningEnum): - """ - Defines special dictionary keys used to track sampler progress with multiple workers. 
- """ + """Defines special dictionary keys used to track captured dataset state with multiple workers.""" - PL_SAMPLERS = "__pl_samplers" + PL_RESTART_META = "__pl_restart_meta" diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 3e4822bbe6..f053f13297 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -14,16 +14,25 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator +from copy import deepcopy +from functools import partial from typing import Any, Generator, List, Optional, Tuple from torch.utils.data.dataloader import DataLoader -from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections +from pytorch_lightning.utilities.auto_restart import ( + _add_capture_metadata_collate, + IteratorState, + MergedIteratorState, + patch_dataloader_iterator, +) from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled -class AbstractDataFetcher(ABC): +class AbstractFetcher(ABC): """ This class is used to control batch fetching flow. @@ -48,7 +57,6 @@ class AbstractDataFetcher(ABC): self.batches: List self.fetched: int self.done: bool - self.has_raised: bool self.reset() @@ -58,36 +66,72 @@ class AbstractDataFetcher(ABC): "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``." ) self.dataloader = dataloader + if isinstance(dataloader, DataLoader) and not isinstance(dataloader.collate_fn, partial): + _add_capture_metadata_collate(dataloader) - def add_batch(self, batch: Any) -> None: + def add_batch(self, batch) -> None: self.batches.append(batch) def fetch_batch(self) -> Any: return self.batches.pop(0) + def _apply_patch(self): + def _apply_patch_fn(loader: DataLoader, iterator: Iterator): + if isinstance(loader, CycleIterator): + loader = loader.loader + # cycle_iterator = iterator + iterator = iterator._loader_iter + + if isinstance(loader, DataLoader) and _fault_tolerant_enabled(): + loader._lightning_fetcher = self + patch_dataloader_iterator(loader, iterator, self) + + apply_to_collections(self.loaders, self.loader_iters, (Iterator, DataLoader), _apply_patch_fn) + + def _store_dataloader_iter_state( + self, dataloader_iter: Iterator, dataloader_iter_states: List[IteratorState] + ) -> None: + if getattr(dataloader_iter, "cache_states", None) is None: + dataloader_iter.cache_states = {} + + if getattr(dataloader_iter, "state", None) is None: + dataloader_iter.state = MergedIteratorState() + + for iter_state in dataloader_iter_states: + iter_name = iter_state.name + if iter_name not in dataloader_iter.cache_states: + dataloader_iter.cache_states[iter_name] = [] + dataloader_iter.cache_states[iter_name].append(iter_state) + + if self.fetched >= self.prefetch_batches: + for iter_state in dataloader_iter_states: + if len(dataloader_iter.state): + dataloader_iter.previous_state = deepcopy(dataloader_iter.state) + iter_name = iter_state.name + state = dataloader_iter.cache_states[iter_name].pop(0) + dataloader_iter.state.update(iter_name, state) + @property def loaders(self) -> List[DataLoader]: - if not self.dataloader: + if self.dataloader is None: raise MisconfigurationException( "The `DataFetcher` should be 
setup with an instance of a PyTorch ``DataLoader``." ) if isinstance(self.dataloader, CombinedLoader): loaders = self.dataloader.loaders - elif isinstance(self.dataloader, (tuple, list)): - loaders = self.dataloader else: loaders = [self.dataloader] return loaders @property def loader_iters(self) -> List[Iterator]: - if not self.dataloader: + if self.dataloader is None: raise MisconfigurationException( "The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``." ) - if not self.dataloader_iter: - raise MisconfigurationException("The dataloader_iter isn't available outside the __iter__ context.") + if self.dataloader_iter is None: + raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.") if isinstance(self.dataloader, CombinedLoader): loader_iters = self.dataloader_iter.loader_iters @@ -107,15 +151,17 @@ class AbstractDataFetcher(ABC): raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.") self.reset() self.dataloader_iter = iter(self.dataloader) + self._apply_patch() return self.fetching_function() def reset(self) -> None: self.batches: List = [] + self.dataloader: Optional[Iterable] self.fetched: int = 0 self.done: bool = False -class LightningDataFetcher(AbstractDataFetcher): +class LightningDataFetcher(AbstractFetcher): """ This class is used to control batch fetching flow. diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index c72dba8b4b..361600b9cd 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -30,9 +30,8 @@ from torch.utils.data.dataset import Dataset, IterableDataset import tests.helpers.utils as tutils from pytorch_lightning import Callback, seed_everything, Trainer -from pytorch_lightning.trainer.supporters import CombinedLoader -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( + _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, CaptureIterableDataset, @@ -263,7 +262,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers): dataset = CaptureIterableDataset(dataset) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) iter_dataloader = iter(dataloader) batches = [] @@ -286,7 +285,7 @@ def test_fast_forward_sampler_over_iterative_dataset(num_workers): dataset = CaptureIterableDataset(dataset) dataset.load_state_dict(state_dict) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) iter_dataloader = iter(dataloader) batches_restart = [] @@ -541,7 +540,7 @@ def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(ra ) dataset = CaptureIterableDataset(dataset) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) epoch_results = [] for _ in range(2): @@ -564,8 +563,8 @@ def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(ra assert torch.equal( epoch_results[0][0]["data"]["selected_indexes"], epoch_results[0][1]["data"]["selected_indexes"] ) - assert 0 in 
epoch_results[0][2][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"] # worker id 0 - assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_SAMPLERS]["iter_sampler"] # worker id 1 + assert 0 in epoch_results[0][2][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"] # worker id 0 + assert 1 in epoch_results[0][3][AutoRestartBatchKeys.PL_RESTART_META]["iter_sampler"] # worker id 1 assert not torch.equal(epoch_results[0][2]["data"][0], epoch_results[0][3]["data"][0]) else: first_task_metadata = all_gather(epoch_results[0][0]["data"]["task_length"], worldsize) @@ -602,7 +601,7 @@ def _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(ra dataset = CaptureIterableDataset(dataset) dataset.load_state_dict(state_dict) dataloader = DataLoader(dataset, num_workers=num_workers, batch_size=1, generator=generator) - Trainer._add_sampler_metadata_collate(dataloader) + _add_capture_metadata_collate(dataloader) epoch_results_restart = [] for _ in range(2): @@ -661,130 +660,6 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w return dataset -def create_dataloader(): - dataset = range(50) - num_workers = 2 - batch_size = 8 - sampler = FastForwardSampler(SequentialSampler(dataset)) - sampler.setup(batch_size) - - dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size) - dataloader.fast_forward_sampler = sampler - - loader_dict = { - "a": [DataLoader(create_iterable_dataset(3, num_workers), num_workers=num_workers, batch_size=3), dataloader], - "b": DataLoader( - create_iterable_dataset(2, num_workers=1, attr_name="custom_sampler"), num_workers=0, batch_size=2 - ), - } - apply_to_collection(loader_dict, DataLoader, Trainer._add_sampler_metadata_collate) - return CombinedLoader(loader_dict) - - -# Lightning will wrap the iterator within a prefect function as follow. -def prefetch_iterator(iterable: Iterable): - it = iter(iterable) - - try: - # the iterator may be empty from the beginning - last = next(it) - except StopIteration: - return - - for val in it: - # yield last and has next - yield last, False, it - last = val - # yield last, no longer has next - yield last, True, it - - -@pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 15 sec and should be skipped in Azure CI") -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) -@RunIf(min_torch="1.7.0") -def test_combined_dataloader_state_dict_and_reload(): - """ - This test makes sure the CombinedLoader used in the condition of Lightning properly - capture its children DataLoader states. - """ - dataloader = create_dataloader() - - iter_dataloader = iter(prefetch_iterator(dataloader)) - num_batches_processed = 4 - for idx in range(1, num_batches_processed): - _, _, prefetched_iterator = next(iter_dataloader) - - loader_iters = prefetched_iterator._loader_iters - - # when dealing with IterativeDataset, - # the sampler state dict will be attached directly onto the iterator to simplify collection. 
- - if idx == 1: - assert loader_iters["a"][0]._sampler_state_dict == [{"iter_sampler": {0: {"current_iteration": 3}}}] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 2}}}] - elif idx == 2: - assert loader_iters["a"][0]._sampler_state_dict == [ - {"iter_sampler": {0: dict(current_iteration=3), 1: dict(current_iteration=3)}} - ] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 4}}}] - else: - assert loader_iters["a"][0]._sampler_state_dict == [ - {"iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=3)}} - ] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [{"custom_sampler": {0: {"current_iteration": 6}}}] - - state_dict = dataloader.state_dict(num_batches_processed=3) - - expected = { - "b": {"num_workers": 0, "previous_worker": None, "custom_sampler": {0: dict(current_iteration=6)}}, - "a": [ - { - "num_workers": 2, - "previous_worker": 1, - "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=3)}, - }, - {"num_workers": 0, "previous_worker": None, 0: dict(current_iteration=24)}, - ], - } - assert state_dict == expected - - dataloader = create_dataloader() - apply_to_collection(dataloader, DataLoader, Trainer._add_sampler_metadata_collate) - dataloader.load_state_dict(state_dict) - - iter_dataloader = iter(prefetch_iterator(dataloader)) - _, _, prefetched_iterator = next(iter_dataloader) - - loader_iters = prefetched_iterator._loader_iters - - assert loader_iters["a"][0]._sampler_state_dict == [ - {"num_workers": 2, "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=6)}} - ] - assert loader_iters["a"][1]._sampler_state_dict == [] - assert loader_iters["b"]._sampler_state_dict == [ - {"num_workers": 0, "custom_sampler": {0: dict(current_iteration=8)}} - ] - - state_dict = dataloader.state_dict(num_batches_processed=4) - - expected = { - "a": [ - { - "num_workers": 2, - "previous_worker": 0, - "iter_sampler": {0: dict(current_iteration=6), 1: dict(current_iteration=6)}, - }, - {"num_workers": 0, "previous_worker": None, 0: dict(current_iteration=32)}, - ], - "b": {"num_workers": 0, "previous_worker": None, "custom_sampler": {0: dict(current_iteration=8)}}, - } - - assert state_dict == expected - - def test_dataloader_to_state_dict_and_reload(): """ Note: Those utilities are used only with DataLoader wrapping a ``mapping`` based dataset. 
@@ -804,7 +679,7 @@ def test_dataloader_to_state_dict_and_reload(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == {"num_workers": 0, "previous_worker": None, 0: {"current_iteration": 16}} + assert state_dict[0]["current_iteration"] == 16 dataloader = create_dataloader() dataloader = _dataloader_load_state_dict(dataloader, state_dict) @@ -812,7 +687,7 @@ def test_dataloader_to_state_dict_and_reload(): _ = next(iter_dataloader) state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader) - assert state_dict == {"num_workers": 0, "previous_worker": None, 0: {"current_iteration": 24}} + assert state_dict[0]["current_iteration"] == 24 @RunIf(min_torch="1.7.0") diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index 323245094c..752fafb27d 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -75,13 +75,14 @@ def test_misconfiguration_error(): fetcher = LightningDataFetcher() with pytest.raises( - MisconfigurationException, match="The `DataFetcher` should be setup with an instance of a PyTorch" + MisconfigurationException, + match="The `DataFetcher` should be setup with an instance of a PyTorch ``DataLoader``.", ): fetcher.setup(range(10)) fetcher = LightningDataFetcher() with pytest.raises( - MisconfigurationException, match="The dataloader_iter isn't available outside the __iter__ context." + MisconfigurationException, match="The `dataloader_iter` isn't available outside the __iter__ context." ): loader = DataLoader(range(10)) fetcher.setup(loader)
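
Example (not part of the patch): a minimal sketch of how the new `CaptureMapDataset` and `_add_capture_metadata_collate` cooperate, matching the batch layout documented in `_capture_metadata_collate` above. The wrapped collate function returns the usual collated batch under `"data"` and the captured worker state under `AutoRestartBatchKeys.PL_RESTART_META`. `RangeDataset` is a hypothetical toy dataset added only for illustration, and the imported helpers are private utilities introduced by this patch, so their names may change.

```python
# Illustrative sketch only -- these are private utilities introduced by this patch.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, CaptureMapDataset
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys


class RangeDataset(Dataset):  # hypothetical toy dataset, not part of the patch
    def __len__(self) -> int:
        return 16

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.tensor([index], dtype=torch.float32)


# Wrap the user dataset so every sample also carries its worker state.
dataset = CaptureMapDataset(RangeDataset())
dataloader = DataLoader(dataset, batch_size=4, num_workers=0)
# Replace the collate_fn so the captured state travels alongside the collated batch.
_add_capture_metadata_collate(dataloader)

batch = next(iter(dataloader))
# Collated data plus the captured metadata (worker 0, no RNG state captured yet).
print(batch["data"].shape)                          # torch.Size([4, 1])
print(batch[AutoRestartBatchKeys.PL_RESTART_META])  # {0: {'rng_states': {}}}
```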
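
Example (not part of the patch): a minimal sketch of the `FastForwardSampler` save/restore contract that the rest of the patch builds on, modeled on the removed `create_dataloader` test helper. It assumes a batch size of 1 and the per-worker state layout asserted in the tests (`{worker_id: {"current_iteration": ...}}`); the exact `state_dict` signature is taken from its use in `patch_dataloader_iterator` above.

```python
# Minimal sketch, assuming the setup()/state_dict()/load_state_dict() API used in the tests.
from torch.utils.data import SequentialSampler

from pytorch_lightning.utilities.auto_restart import FastForwardSampler

sampler = FastForwardSampler(SequentialSampler(range(10)))
sampler.setup(dataloader_batch_size=1)

it = iter(sampler)
consumed = [next(it) for _ in range(3)]  # [0, 1, 2]
state = sampler.state_dict(3)            # state after 3 batches of size 1
# state should look like {0: {"current_iteration": 3}} (worker id 0 in the main process)

restored = FastForwardSampler(SequentialSampler(range(10)))
restored.setup(dataloader_batch_size=1)
restored.load_state_dict(state)
# Iterating again fast-forwards past the already-consumed indices.
assert list(restored) == [3, 4, 5, 6, 7, 8, 9]
```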
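
Example (not part of the patch): a small sketch of how per-worker `IteratorState` entries are merged and serialized. For a map-style dataset the `generator_name` passed to `update` is `None`, so the merged state is keyed directly by worker id. `dataclasses.asdict` is only a stand-in here for whatever checkpointing layer produces the plain-dict form consumed by `from_state_dict`.

```python
# Sketch of the dataclass round trip; asdict() stands in for the checkpointing layer.
from dataclasses import asdict

from pytorch_lightning.utilities.auto_restart import IteratorState, MergedIteratorState

merged = MergedIteratorState()
# Map-style datasets do not own a generator, so the generator name is None.
merged.update(None, IteratorState(worker_id=0, num_workers=2, num_batches_fetched=3))
merged.update(None, IteratorState(worker_id=1, num_workers=2, num_batches_fetched=4))

plain = asdict(merged)  # nested dataclasses become plain dicts
restored = MergedIteratorState.from_state_dict(plain)

assert restored.latest_worker_id == 1
assert restored.state[0].num_batches_fetched == 3
assert restored.state[1].num_batches_fetched == 4
```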