# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial, wraps
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import (
    BatchSampler,
    Dataset,
    DistributedSampler,
    get_worker_info,
    RandomSampler,
    Sampler,
    SequentialSampler,
)
from torch.utils.data.dataloader import (
    _BaseDataLoaderIter,
    _MultiProcessingDataLoaderIter,
    _SingleProcessDataLoaderIter,
    DataLoader,
    IterableDataset,
)

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _Stateful


class FastForwardSampler(Sampler):
    """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations
    performed during an epoch.

    It maintains a state, saved with :meth:`state_dict`, that can be reloaded with
    :meth:`load_state_dict`. If the sampler is used in a multiprocessing context, the ``FastForwardSampler`` will record
    the state of the current worker.
    When reloading, the ``FastForwardSampler`` will "fast-forward" the wrapped sampler by iterating through all the
    samples seen in the last iterations (for the current worker).
    """

    def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
        super().__init__(data_source=None)
        self._sampler = sampler
        self.restarting: bool = False
        self._current_iteration = 0
        self._counter = 0
        self._dataloader_batch_size: Optional[int] = None
        self._cached_state_dict: Optional[Dict[int, Any]] = None
        self._attr_name = attr_name

    def __getattr__(self, key: str) -> Any:
        if key in self.__dict__:
            return self.__dict__[key]
        return getattr(self._sampler, key, None)

    def setup(self, dataloader_batch_size: Optional[int] = None) -> None:
        """Setup the ``FastForwardSampler``.

        This is required only when the provided dataset subclasses :class:`torch.utils.data.Dataset`.
        """
        self._dataloader_batch_size = dataloader_batch_size

    @property
    def worker_id(self) -> int:
        worker_info = get_worker_info()
        return worker_info.id if worker_info else 0

    def __iter__(self) -> Iterator[Any]:
        self.sampler_iter = iter(self._sampler)
        self._current_iteration = 0
        self._counter = 0
        return self

    def __next__(self):
        # the `state_dict` was cached as workers were unavailable before.
        if self._cached_state_dict is not None:
            self._load_non_random_state(self._cached_state_dict)

        while self._counter < self._current_iteration:
            next(self.sampler_iter)
            self._counter += 1

        # here: self._counter == self._current_iteration
        if self._cached_state_dict is not None:
            self._cached_state_dict = None

        # the wrapped sampler is now aligned; fetch the next sample and advance the counters.
        self._current_iteration += 1
        self._counter += 1
        has_raised = False
        try:
            return next(self.sampler_iter)
        except StopIteration:
            has_raised = True

        self._current_iteration = 0
        self._counter = 0
        self._cached_state_dict = None
        self.restarting = False
        if has_raised:
            raise StopIteration

    def __len__(self) -> int:
        return len(self._sampler)

    def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
        """Returns the state of the sampler in the current worker.

        The worker id indexes the state dict.
        """
        return {self.worker_id: {"current_iteration": self._compute_current_iteration(num_batches_processed)}}

    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
        """Loads the saved state for the wrapped sampler.

        If the ``state_dict`` contains multiple states, it means there were multiple workers. The state will be cached
        and fully reloaded (fast-forwarded) the first time :meth:`__iter__` is called.
        """
        # as workers aren't available yet, the ``state_dict`` is cached until workers are made available.
        state_dict = deepcopy(state_dict)
        self._cached_state_dict = state_dict
        self.restarting = True

    def _compute_current_iteration(self, num_batches_processed: Optional[int] = None) -> int:
        """This function is used to compute the effective iteration.

        As a DataLoader can perform prefetching or training can fail while processing a batch, the current iteration
        needs to be computed using the ``num_batches_processed`` information.
        """
        if num_batches_processed is not None:
            current_iteration = num_batches_processed
        else:
            current_iteration = self._current_iteration

        if self._dataloader_batch_size and num_batches_processed is not None:
            current_iteration *= self._dataloader_batch_size

        return current_iteration

    def _load_non_random_state(self, state_dict: Dict[int, Dict[str, Any]]) -> None:
        self._current_iteration = state_dict[self.worker_id]["current_iteration"]
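
# Illustrative usage sketch (not part of the original module): how a `FastForwardSampler`
# state might be captured and later fast-forwarded after a restart, assuming a plain
# `SequentialSampler` over ten indices and a single (main) process, i.e. worker id 0.
#
#     sampler = FastForwardSampler(SequentialSampler(range(10)))
#     sampler.setup(dataloader_batch_size=None)
#     it = iter(sampler)
#     _ = [next(it) for _ in range(3)]          # 3 indices consumed before a "failure"
#     state = sampler.state_dict()              # {0: {"current_iteration": 3}}
#
#     restored = FastForwardSampler(SequentialSampler(range(10)))
#     restored.setup(dataloader_batch_size=None)
#     restored.load_state_dict(state)
#     next(iter(restored))                      # fast-forwards past 0, 1, 2 and yields index 3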


@dataclass(frozen=True, unsafe_hash=True)
class IteratorState:
    """The state of an iterator in a single worker process."""

    dataset_state: Dict[int, Any] = field(default_factory=dict)
    sampler_state: Dict[int, Any] = field(default_factory=dict)
    worker_id: int = 0
    num_workers: int = 0
    num_batches_fetched: int = 0
    name: Optional[str] = None

    @classmethod
    def from_state_dict(cls, state_dict) -> "IteratorState":
        return cls(**state_dict)


@dataclass
class MergedIteratorState:
    """This class is used to hold the current iterator state and lives on the iterator.

    It holds the current merged states from all worker processes. Once an iterator advances, it can store updates of
    the worker states in this merged iterator state.
    """

    state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
    latest_worker_id: int = 0
    represent_map_dataset: Optional[bool] = None

    def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
        # a map-based dataset doesn't own a generator and therefore `generator_name` should be None.
        self.represent_map_dataset = generator_name is None
        if self.represent_map_dataset:
            state = self.state
        else:
            if generator_name not in self.state:
                self.state[generator_name] = {}
            state = self.state[generator_name]

        latest_worker_id = new_state.worker_id
        state[latest_worker_id] = new_state
        self.latest_worker_id = latest_worker_id

    @property
    def sampler_states(self) -> Dict[int, Any]:
        """Returns the merged sampler states for all worker processes."""
        return {0: self.state[k].sampler_state[0] for k in self.state.keys()}

    @property
    def dataset_states(self) -> Dict[int, Any]:
        """Returns the merged dataset states for all worker processes."""
        return {k: self.state[k].dataset_state[k] for k in self.state.keys()}

    @classmethod
    def from_state_dict(cls, state_dict) -> "MergedIteratorState":
        if state_dict["represent_map_dataset"]:
            state_dict["state"] = {
                worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
            }
        else:
            state_dict["state"] = {
                sampler_name: {
                    worker_id: IteratorState.from_state_dict(state) for worker_id, state in worker_state.items()
                }
                for sampler_name, worker_state in state_dict["state"].items()
            }
        return cls(**state_dict)

    def __len__(self) -> int:
        return len(self.state)
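
# Illustrative sketch (not part of the original module): merging per-worker states for a
# map-style dataset, where `generator_name` is None and the merged state is keyed by worker id.
#
#     merged = MergedIteratorState()
#     merged.update(None, IteratorState(worker_id=0, num_workers=2, num_batches_fetched=5))
#     merged.update(None, IteratorState(worker_id=1, num_workers=2, num_batches_fetched=6))
#     merged.latest_worker_id                   # 1, the last worker that produced a batch
#     len(merged)                               # 2 worker entries in `merged.state`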


class CaptureMapDataset(Dataset):
    """This class is used to capture the state of a map-based dataset.

    Note:
        We currently don't support restoring if we fail during the first `N = num_workers` batches, where
        `num_workers` is the number of workers spawned by the dataloader.
    """

    def __init__(self, dataset: Dataset) -> None:
        self.dataset = dataset
        self._cached_state_dict = None

    @property
    def worker_id(self) -> int:
        worker_info = get_worker_info()
        return worker_info.id if worker_info else 0

    def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
        if self._cached_state_dict is not None:
            if self.worker_id in self._cached_state_dict:
                set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
            self._cached_state_dict = None

        return self.dataset[item]

    def __len__(self) -> int:
        return len(self.dataset)

    def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
        # as workers aren't available yet, the ``state_dict`` is cached until workers are made available.
        self._cached_state_dict = _rotate_worker_indices(deepcopy(state_dict), latest_worker_id, num_workers)

    def state_dict(self) -> Dict[int, Dict[str, Any]]:
        return {self.worker_id: {"rng_states": collect_rng_states()}}


def collect_rng_states() -> Dict[str, Any]:
    """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
    return {"torch": torch.get_rng_state(), "numpy": np.random.get_state(), "python": python_get_rng_state()}


def set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
    torch.set_rng_state(rng_state_dict.get("torch"))
    np.random.set_state(rng_state_dict.get("numpy"))
    version, state, gauss = rng_state_dict.get("python")
    python_set_rng_state((version, tuple(state), gauss))
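
# Illustrative sketch (not part of the original module): capturing the global RNG state and
# restoring it so that subsequent random draws repeat exactly.
#
#     states = collect_rng_states()
#     a = (torch.rand(1), np.random.rand())
#     set_rng_states(states)
#     b = (torch.rand(1), np.random.rand())
#     # `a` and `b` are identical because the global RNG state was restored in between.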


class CaptureIterableDataset(IterableDataset):
    """The ``CaptureIterableDataset`` is used to wrap an :class:`torch.utils.data.IterableDataset`.

    On ``__iter__`` function call, the ``CaptureIterableDataset`` will wrap the wrapped dataset generators into
    ``FastForwardSampler`` to keep track of progress. On ``__next__`` function call, the ``CaptureIterableDataset`` will
    return a dictionary containing user data and metadata containing the ``FastForwardSampler`` samplers state_dict.
    """

    def __init__(self, dataset: IterableDataset) -> None:
        super().__init__()
        self.dataset = deepcopy(dataset)
        self.samplers: Optional[Dict[str, FastForwardSampler]] = None
        self._state_dict: Optional[Dict[int, Any]] = None
        self._has_wrapped: bool = False

    @property
    def sampler(self) -> Sampler:
        return self.dataset.sampler

    def state_dict(self) -> Dict[str, Any]:
        return {k: v.state_dict() for k, v in self.samplers.items()}

    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
        self._state_dict = deepcopy(state_dict)

    def _wrap_generator_samplers(self) -> None:
        self.samplers = {}

        # access wrapped dataset attributes
        dataset_dict = self.dataset.__dict__

        # create a dictionary of the generators present within the dataset attributes
        dataset_sampler_generators = {k: v for k, v in dataset_dict.items() if isinstance(v, (Generator, Iterator))}

        # iterate over the generators. If a generator was created from a `Sampler`,
        # it will be wrapped into a `FastForwardSampler`.
        for (generator_attr_name, generator) in dataset_sampler_generators.items():

            if isinstance(generator, Sampler):
                continue

            # wrap the generator into a `FastForwardSampler`
            sampler = FastForwardSampler(generator, attr_name=generator_attr_name)

            # if `CaptureIterableDataset` was available, the sampler should reload its own state.
            if self._state_dict is not None:
                sampler.load_state_dict(self._state_dict[generator_attr_name])
            # store the samplers
            self.samplers[generator_attr_name] = sampler

            # replace the generator with the generator from the `FastForwardSampler`.
            dataset_dict[generator_attr_name] = iter(sampler)

        self.reset_on_epoch()

    def reset_on_epoch(self):
        self._state_dict = None

    def __iter__(self) -> Iterator:
        # create an iterator from the wrapped Iterable Dataset.
        # if the dataset contained samplers, they will have been transformed into generators.
        self.iter_data = iter(self.dataset)

        # wrap any generator associated to a Sampler into a `FastForwardSampler`.
        if isinstance(self.iter_data, Generator):
            raise MisconfigurationException(
                "PyTorch Lightning Fault-Tolerant feature does not support `__iter__` returning a generator."
                " Please use the `__next__` function to fetch the next batch and use a sampler for"
                " doing your iterations."
            )
        self._wrap_generator_samplers()
        return self

    def __next__(self) -> Dict[str, Any]:
        return next(self.iter_data)
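
# Illustrative sketch (not part of the original module): an `IterableDataset` whose
# `__iter__` turns its sampler into an iterator attribute, so `CaptureIterableDataset` can
# wrap that iterator in a `FastForwardSampler` and track its progress. `MyIterableDataset`
# is a hypothetical user dataset introduced only for this example.
#
#     class MyIterableDataset(IterableDataset):
#         def __init__(self, data):
#             self.data = data
#             self.sampler = SequentialSampler(data)
#
#         def __iter__(self):
#             self.sampler_iter = iter(self.sampler)   # iterator attribute that gets captured
#             return self
#
#         def __next__(self):
#             return self.data[next(self.sampler_iter)]
#
#     dataset = CaptureIterableDataset(MyIterableDataset(list(range(10))))
#     it = iter(dataset)                 # wraps `sampler_iter` into a `FastForwardSampler`
#     next(it)                           # 0
#     dataset.state_dict()               # {"sampler_iter": {0: {"current_iteration": 1}}}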


def _find_fast_forward_samplers(dataloader: DataLoader) -> Optional[FastForwardSampler]:
    """If the ``DataLoader`` is wrapping a mapping based Dataset, return the ``FastForwardSampler``."""
    if isinstance(dataloader.sampler, FastForwardSampler):
        return dataloader.sampler

    if isinstance(dataloader.batch_sampler, FastForwardSampler):
        return dataloader.batch_sampler


def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> Iterator:
    """This function is used to cycle back the DataLoader ``_MultiProcessingDataLoaderIter`` workers and call the
    reset function.

    Returns:
        iterator: Return the iterator generated from the provided ``DataLoader``.
    """
    # create iterator from dataloader
    iter_dataloader = iter(dataloader)
    # get current num workers
    num_workers = getattr(iter_dataloader, "_num_workers", 0)
    # as the `state_dict` is worker dependent, Lightning doesn't support changing
    # the `num_workers` for Fault-tolerance
    if state_dict["num_workers"] != num_workers:
        raise MisconfigurationException(
            f"The provided `num_workers` {num_workers} doesn't match the one used "
            f"while generating the checkpoint: {state_dict['num_workers']}"
        )
    # when using multiple workers, we will cycle back the worker queue idx to
    # start back on the failed worker.
    if isinstance(iter_dataloader, _MultiProcessingDataLoaderIter):
        # move back to 0
        while next(iter_dataloader._worker_queue_idx_cycle) != 0:
            pass
        # increment previous worker
        if isinstance(state_dict["previous_worker"], int):
            for _ in range(state_dict["previous_worker"] - 1):
                next(iter_dataloader._worker_queue_idx_cycle)

        # we can finally call reset and apply prefetching.
        iter_dataloader._reset = iter_dataloader._original_reset
        iter_dataloader._reset(dataloader, first_iter=True)
    # return the iterator
    return iter_dataloader


def _capture_metadata_collate(
    samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode
) -> Any:
    """A collate_fn function that adds the state dict of a :class:`CaptureIterableDataset` or
    :class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker
    processes. The structure will be:

    .. code-block:: python

        {
            "data": ...,  # data returned by Dataset
            "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
        }
    """
    data = collate_fn(samples)
    metadata = None
    if fault_tolerant_mode.is_automatic:
        metadata = dataset.state_dict()
    else:
        state_dict_fn = getattr(dataset, "state_dict", None)
        info = get_worker_info()
        worker_id = info.id if info else 0
        if state_dict_fn is not None:
            metadata = state_dict_fn()
            if worker_id not in metadata:
                if info and info.num_workers > 1:
                    raise MisconfigurationException(
                        f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys."
                    )
                metadata = {0: metadata}
        if metadata is None:
            metadata = {worker_id: {}}

    return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}


# TODO: Merge this code within stateful DataLoaderIter.
def _next_data_wrapper(
    fn: Callable,
    it: Iterator,
    dl: DataLoader,
    num_batches_fetched: int,
    data_fetcher: "pl.utilities.fetching.AbstractDataFetcher",
) -> Callable:
    @wraps(fn)
    def wrapper() -> Any:
        nonlocal num_batches_fetched

        dataset = dl.dataset
        combined_batch = fn()

        batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]
        num_batches_fetched += 1

        if isinstance(dataset, CaptureIterableDataset):
            state = [
                IteratorState(
                    num_workers=dl.num_workers,
                    sampler_state=iterator_state,
                    num_batches_fetched=num_batches_fetched,
                    worker_id=list(iterator_state.keys())[0],
                    name=sampler_iter_name,
                )
                for sampler_iter_name, iterator_state in state.items()
            ]
        elif isinstance(dataset, CaptureMapDataset):
            ff_sampler = _find_fast_forward_samplers(dl)
            state = [
                IteratorState(
                    num_workers=dl.num_workers,
                    sampler_state=ff_sampler.state_dict(num_batches_fetched),
                    dataset_state=state,
                    worker_id=list(state.keys())[0],
                    num_batches_fetched=num_batches_fetched,
                )
            ]
        data_fetcher._store_dataloader_iter_state(it, state)
        return batch

    return wrapper


def patch_dataloader_iterator(
    dataloader: DataLoader,
    iterator: Iterator,
    data_fetcher: "pl.utilities.fetching.AbstractDataFetcher",
    num_batches_fetched: int = 0,
) -> None:
    """Patches the iterator of a PyTorch dataloader by injecting the logic for fault-tolerant training; the wrapper
    extracts the sampler state dict from the provided data batch before returning it.

    The custom data has this format:

    .. code-block:: python

        {
            "data": ...,  # data returned by the DataLoader
            "__pl_restart_meta": {
                "sampler0": {
                    0: {"current_iteration": ...},
                    1: {"current_iteration": ...},
                },
                "sampler1": ...,
            },
        }

    Each sampler in the worker process tracks the current iteration. We return all of them to the main process
    as part of the sample and then a special collate function :func:`_capture_metadata_collate`
    will extract the current iteration as part of the metadata returned by a custom batch.
    """

    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))
    iterator._next_data = _next_data_wrapper(
        iterator._next_data, iterator, dataloader, num_batches_fetched, data_fetcher
    )


def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
    """Wrap default collate function to retrieve captured dataset state dict when fault tolerant is enabled."""
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
    collate_fn = dataloader.collate_fn
    if not fault_tolerant_mode.is_enabled or (
        isinstance(collate_fn, partial) and collate_fn.func is _capture_metadata_collate
    ):
        return
    dataloader.collate_fn = partial(
        _capture_metadata_collate,
        dataset=dataloader.dataset,
        collate_fn=collate_fn,
        fault_tolerant_mode=fault_tolerant_mode,
    )
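
# Illustrative sketch (not part of the original module): with fault tolerance enabled
# (e.g. the `PL_FAULT_TOLERANT_TRAINING=1` environment variable), wrapping a DataLoader's
# collate function makes every fetched item a dict carrying both the data and the restart metadata.
#
#     dataloader = DataLoader(CaptureMapDataset(range(10)), batch_size=2)
#     _add_capture_metadata_collate(dataloader)
#     combined = next(iter(dataloader))
#     combined["data"]                                  # the usual collated batch
#     combined[AutoRestartBatchKeys.PL_RESTART_META]    # {worker_id: {"rng_states": ...}}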


def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    iterator_state = state_dict["state"][0]

    if not isinstance(iterator_state, IteratorState):
        iterator_state = IteratorState.from_state_dict(iterator_state)

    # reload sampler state
    ff_sampler = _find_fast_forward_samplers(dataloader)
    ff_sampler.load_state_dict(iterator_state.sampler_state)

    # reload dataset state
    dataloader.dataset.load_state_dict(
        iterator_state.dataset_state,
        latest_worker_id=state_dict["latest_worker_id"],
        num_workers=iterator_state.num_workers,
    )


def _reload_dataloader_state_dict_automatic_iterable_dataset(
    dataset: CaptureIterableDataset, state_dict: Dict[str, Any]
) -> None:
    dataset.load_state_dict(
        {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
    )


def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    dataset = dataloader.dataset
    if isinstance(dataset, CaptureMapDataset):
        _reload_dataloader_state_dict_automatic_map_dataset(dataloader, state_dict)

    elif isinstance(dataset, CaptureIterableDataset):
        _reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict)

    else:
        raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")


def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`;
    # therefore, we need to reload the states manually.
    latest_worker_id = state_dict["latest_worker_id"]
    num_workers = state_dict["state"][latest_worker_id]["num_workers"]
    sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
    if sampler_state:
        # `sampler_state` keys contain all the DataLoader attribute names
        # which matched the `_Stateful` API interface while collecting the `state_dict`.
        for dataloader_attr_name in sampler_state:
            obj = getattr(dataloader, dataloader_attr_name)
            if not isinstance(obj, _Stateful):
                raise MisconfigurationException(
                    f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
                )

            obj.load_state_dict(sampler_state[dataloader_attr_name])

    if not isinstance(dataloader.dataset, _Stateful):
        return

    dataset_state = {
        worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id]
        for worker_id in state_dict["state"].keys()
    }

    dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers))


def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    """Utility to reload state_dict within dataloader for fault tolerance."""

    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()

    if not fault_tolerant_mode.is_enabled:
        return

    if fault_tolerant_mode.is_automatic:
        _reload_dataloader_state_dict_automatic(dataloader, state_dict)

    elif fault_tolerant_mode.is_manual:
        _reload_dataloader_state_dict_manual(dataloader, state_dict)

    else:
        raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")


def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
    """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed
    on."""
    if num_workers == 0:
        return state
    if latest_worker_id > num_workers - 1:
        raise MisconfigurationException("The `latest_worker_id` should be within [0, num_workers - 1].")
    if len(state) != num_workers:
        raise MisconfigurationException("The `state` should contain `num_workers` values.")
    next_worker_id = latest_worker_id + 1
    old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
    return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
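
# Illustrative sketch (not part of the original module): with 3 workers and a failure right
# after worker 1 produced a batch, the next worker to fetch (worker 2) is rotated to index 0
# so the restart resumes in the original worker order.
#
#     _rotate_worker_indices({0: "w0", 1: "w1", 2: "w2"}, latest_worker_id=1, num_workers=3)
#     # -> {0: "w2", 1: "w0", 2: "w1"}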


class _StatefulDataLoaderIter:
    """This mixin is used to make PyTorch DataLoaderIter stateful."""

    def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
        # store sampler state within a queue alongside its idx.
        self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
        self._sampler_state.append((sampler_state, self._sampler_state_idx))

    def _store_sampler_state(self) -> None:
        """This function is used to extract the sampler states if any."""
        sampler_state = {
            k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
        }
        self.__accumulate_state(sampler_state)

    def _next_index(self) -> Any:
        indexes = super()._next_index()
        self._store_sampler_state()
        return indexes

    def _prepare_loader(self, loader):
        _add_capture_metadata_collate(loader)
        self._loader = loader
        self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
        self.num_batches_fetched = 0
        self._sampler_state = []
        self._sampler_state_idx = 0

    def __del__(self) -> None:
        if isinstance(self._loader.collate_fn, partial):
            self._loader.collate_fn = self._loader.collate_fn.keywords["collate_fn"]

    def _next_data(self) -> Any:
        combined_batch = super()._next_data()

        batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]

        self.num_batches_fetched += 1

        sampler_state, sampler_state_idx = self._sampler_state.pop(0)
        # there are no workers within the samplers
        worker_id = list(state.keys())[0]

        state = [
            IteratorState(
                num_workers=self._loader.num_workers,
                sampler_state=sampler_state,
                dataset_state=state,
                worker_id=worker_id,
                num_batches_fetched=self.num_batches_fetched,
            )
        ]
        # ensures there is an alignment between the sampler state and currently fetched batch
        assert sampler_state_idx == self.num_batches_fetched
        self._data_fetcher._store_dataloader_iter_state(self, state)
        return batch


class _SingleProcessDataLoaderIterStateful(_StatefulDataLoaderIter, _SingleProcessDataLoaderIter):
    def __init__(self, loader: DataLoader):
        self._prepare_loader(loader)
        super().__init__(loader)


class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProcessingDataLoaderIter):
    def __init__(self, loader: DataLoader):
        self._prepare_loader(loader)
        super().__init__(loader)


def _get_iterator(self) -> "_BaseDataLoaderIter":
    if not hasattr(self, "_lightning_fetcher"):
        raise MisconfigurationException(
            "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
        )
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIterStateful(self)
    else:
        if hasattr(self, "check_worker_number_rationality"):
            self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIterStateful(self)


def _patch_dataloader_get_iterators() -> None:
    """This function is used to replace the DataLoader iterator with its stateful version."""
    if not _FaultTolerantMode.detect_current_mode().is_manual:
        return
    if not hasattr(DataLoader, "_ori_get_iterator"):
        DataLoader._ori_get_iterator = DataLoader._get_iterator
    DataLoader._get_iterator = _get_iterator


def _teardown_dataloader_get_iterators() -> None:
    """This function is used to restore the DataLoader `_get_iterator` to its original implementation."""
    # cleanup the get_iterator replacement in case of Fault-tolerance.
    get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
    if get_iterator:
        DataLoader._get_iterator = get_iterator
        del DataLoader._ori_get_iterator
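
# Illustrative sketch (not part of the original module): in manual fault-tolerant mode the
# patch swaps `DataLoader._get_iterator` for the stateful variant, and the teardown restores
# the original method.
#
#     _patch_dataloader_get_iterators()       # DataLoaders now yield stateful iterators
#     ...                                     # training with sampler/dataset state capture
#     _teardown_dataloader_get_iterators()    # original `_get_iterator` restored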


def _validate_iterable_dataset(dataloader: DataLoader) -> None:
    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

    dataset = dataloader.dataset

    if getattr(dataset, "__next__", None) is None:
        raise AttributeError(
            "Fault-tolerance doesn't support an `IterableDataset` without a `__next__` "
            "method implemented. Hint: We recommend you to move your logic from `__iter__`"
            " into `__next__` and rely on a sampler to perform the sample selection."
        )

    samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)}

    if not samplers:
        raise TypeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.")

    sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS]

    if not sampler:
        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

    if len(sampler) > 1:
        raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.")

    if type(sampler[0]) is DistributedSampler and sampler[0].shuffle:
        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
    elif type(sampler[0]) is not SequentialSampler:
        raise TypeError("Only `SequentialSampler` is supported.")


def _validate_map_dataset(dataloader: DataLoader) -> None:
    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

    sampler = getattr(dataloader, "sampler", None)
    if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

    batch_sampler = getattr(dataloader, "batch_sampler", None)
    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
        raise TypeError("Fault-tolerance supports only a `BatchSampler`.")

    if type(sampler) is DistributedSampler and sampler.shuffle:
        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
    elif type(sampler) is RandomSampler:
        raise TypeError("Only `SequentialSampler` is supported.")


def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
    """This function is used to validate that Fault-tolerance is possible with the user data."""
    if not _FaultTolerantMode.detect_current_mode().is_automatic:
        return

    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

    if isinstance(dataloader, CombinedLoader):
        dataloaders = dataloader.loaders
    else:
        dataloaders = dataloader

    dl_loaders = []

    def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
        nonlocal dl_loaders
        if isinstance(dataloader, CycleIterator):
            dataloader = dataloader.loader
        dl_loaders.append(dataloader)

    apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)

    if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
        raise ValueError("Fault-tolerance supports only a single dataloader.")

    for dataloader in dl_loaders:
        validator_fn = (
            _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
        )
        validator_fn(dataloader)


def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
    """This utility collects the state across processes for a collection of state."""

    def fn(state: Dict):
        if key in state:
            return _collect_states_on_rank_zero(state)
        return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}

    return apply_to_collection(state_dict, Dict, fn)