Fault Tolerant Manual: Add loading to reload the states (#10699)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent dca1776870
commit b28ab34ff5
@@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Add a `_rotate_worker_indices` utility to reload the state according to the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
 * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
 * Add a utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
+* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
-
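Note on the `_rotate_worker_indices` entry above: PyTorch DataLoader workers produce batches round-robin, so after a restart the per-worker state has to be shifted so that the worker due to produce the next batch lines up again. A minimal self-contained sketch of the idea; this is illustrative only, not the library's exact implementation:

```python
from typing import Any, Dict


def rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
    # Shift each worker's state so that the worker scheduled right after the
    # one that produced the last pre-failure batch becomes worker 0 on restart.
    return {
        (worker_id - latest_worker_id - 1) % num_workers: worker_state
        for worker_id, worker_state in state.items()
    }


# Three workers; worker 1 produced the last batch, so worker 2 resumes first.
state = {0: "w0", 1: "w1", 2: "w2"}
assert rotate_worker_indices(state, latest_worker_id=1, num_workers=3) == {1: "w0", 2: "w1", 0: "w2"}
```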
@@ -22,7 +22,7 @@ from deprecate import void
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.utilities import _update_dataloader_iter
 from pytorch_lightning.trainer.progress import BatchProgress
-from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict
+from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -182,7 +182,7 @@ class EvaluationEpochLoop(Loop):

     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
         if not self.trainer.sanity_checking and self._dataloader_state_dict:
-            reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
+            _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
             self._dataloader_state_dict = None

     def _num_completed_batches_reached(self) -> bool:
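The guard in `EvaluationEpochLoop._reload_dataloader_state_dict` above runs the restore exactly once: sanity-checking iterations must not consume the saved state, and clearing the attribute afterwards prevents a second reload on the next evaluation epoch. The same pattern in isolation, with stub names that are purely illustrative:

```python
class EpochLoopStub:
    # Illustrative stand-in for the loop: caches the saved dataloader state
    # and restores it on the first non-sanity-check run only.
    def __init__(self, saved_state):
        self._dataloader_state_dict = saved_state
        self.sanity_checking = False

    def maybe_reload(self, dataloader):
        if not self.sanity_checking and self._dataloader_state_dict:
            # Stand-in for the call to `_reload_dataloader_state_dict`.
            dataloader.load_state_dict(self._dataloader_state_dict)
            self._dataloader_state_dict = None
```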
@@ -24,9 +24,9 @@ from torch.utils.data.dataset import IterableDataset

 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.auto_restart import (
+    _reload_dataloader_state_dict,
     MergedIteratorState,
     patch_dataloader_iterator,
-    reload_dataloader_state_dict,
 )
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -403,7 +403,7 @@ class CombinedLoader:
         if isinstance(dataloader, CycleIterator):
             dataloader = dataloader_to_iter_on.loader

-        reload_dataloader_state_dict(dataloader, state_dict)
+        _reload_dataloader_state_dict(dataloader, state_dict)

         # Finally, spawn the workers, if any.
         it = iter(dataloader_to_iter_on)
@@ -33,7 +33,6 @@ from typing_extensions import Protocol, runtime_checkable
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training


 class FastForwardSampler(Sampler):
@@ -564,38 +563,90 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
     )


-def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
-    """Utility to reload state_dict within dataloader for fault tolerance."""
+def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    iterator_state = state_dict["state"][0]

-    if not _fault_tolerant_training():
-        return
+    if not isinstance(iterator_state, IteratorState):
+        iterator_state = IteratorState.from_state_dict(iterator_state)
+
+    # reload sampler state
+    ff_sampler = _find_fast_forward_samplers(dataloader)
+    ff_sampler.load_state_dict(iterator_state.sampler_state)
+
+    # reload dataset state
+    dataloader.dataset.load_state_dict(
+        iterator_state.dataset_state,
+        latest_worker_id=state_dict["latest_worker_id"],
+        num_workers=iterator_state.num_workers,
+    )
+
+
+def _reload_dataloader_state_dict_automatic_iterable_dataset(
+    dataset: CaptureIterableDataset, state_dict: Dict[str, Any]
+) -> None:
+    dataset.load_state_dict(
+        {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
+    )
+
+
+def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
     dataset = dataloader.dataset

     if isinstance(dataset, CaptureMapDataset):
-        iterator_state = state_dict["state"][0]
-
-        if not isinstance(iterator_state, IteratorState):
-            iterator_state = IteratorState.from_state_dict(iterator_state)
-
-        # reload sampler state
-        ff_sampler = _find_fast_forward_samplers(dataloader)
-        ff_sampler.load_state_dict(iterator_state.sampler_state)
-
-        # reload dataset state
-        dataset.load_state_dict(
-            iterator_state.dataset_state,
-            latest_worker_id=state_dict["latest_worker_id"],
-            num_workers=iterator_state.num_workers,
-        )
+        _reload_dataloader_state_dict_automatic_map_dataset(dataloader, state_dict)

     elif isinstance(dataset, CaptureIterableDataset):
-        dataset.load_state_dict(
-            {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
-        )
+        _reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict)

     else:
-        raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
+        raise MisconfigurationException("This shouldn't be happening. Please open an issue.")
+
+
+def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`;
+    # therefore, we need to reload the states manually.
+
+    latest_worker_id = state_dict["latest_worker_id"]
+    num_workers = state_dict["state"][latest_worker_id]["num_workers"]
+    sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
+    if sampler_state:
+        # The `sampler_state` keys are all the DataLoader attribute names
+        # that matched the `_SupportsStateDict` API interface while collecting the `state_dict`.
+        for dataloader_attr_name in sampler_state:
+            obj = getattr(dataloader, dataloader_attr_name)
+            if not isinstance(obj, _SupportsStateDict):
+                raise MisconfigurationException(
+                    f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
+                )
+
+            obj.load_state_dict(sampler_state[dataloader_attr_name])
+
+    if not isinstance(dataloader.dataset, _SupportsStateDict):
+        return
+
+    dataset_state = {
+        worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id]
+        for worker_id in state_dict["state"].keys()
+    }
+
+    dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers))
+
+
+def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    """Utility to reload state_dict within dataloader for fault tolerance."""
+
+    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
+
+    if not fault_tolerant_mode.is_enabled:
+        return
+
+    if fault_tolerant_mode.is_automatic:
+        _reload_dataloader_state_dict_automatic(dataloader, state_dict)
+
+    elif fault_tolerant_mode.is_manual:
+        _reload_dataloader_state_dict_manual(dataloader, state_dict)
+
+    else:
+        raise MisconfigurationException("This shouldn't be happening. Please open an issue.")


 def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
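The manual path above hinges on the `_SupportsStateDict` protocol check: any DataLoader attribute, or the dataset itself, that exposes `state_dict` and `load_state_dict` is captured and later restored. A hedged sketch of a user-side sampler that would satisfy it; `CounterSampler` is a hypothetical example, not part of this diff:

```python
from torch.utils.data import Sampler


class CounterSampler(Sampler):
    # Hypothetical user sampler. Exposing `state_dict`/`load_state_dict` is all
    # that `_reload_dataloader_state_dict_manual` requires, since the protocol
    # check is structural (duck typing via a runtime-checkable Protocol).
    def __init__(self, data_source):
        self.data_source = data_source
        self.counter = 0

    def __iter__(self):
        for index in range(self.counter, len(self.data_source)):
            self.counter = index + 1
            yield index

    def __len__(self):
        return len(self.data_source)

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]
```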
@@ -638,7 +689,6 @@ class _StatefulDataLoaderIter:
             for k, v in self._loader.__dict__.items()
             if isinstance(v, _SupportsStateDict) and k != "dataset"
         }
-
         self.__accumulate_state(sampler_state)

     def _next_index(self) -> Any:
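The comprehension above is the collection side of the same protocol: it walks the loader's attributes and snapshots everything that supports `state_dict`, excluding the dataset, whose state is gathered per worker elsewhere. A standalone sketch of the duck-typing mechanism; the stub classes are assumptions, not library code:

```python
from typing import Any, Dict, Protocol, runtime_checkable


@runtime_checkable
class SupportsStateDict(Protocol):
    # Mirrors the structural check that `_SupportsStateDict` performs.
    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...


class CounterSampler:
    def __init__(self):
        self.counter = 3

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]


class FakeLoader:
    # Stand-in for a DataLoader: one stateful attribute, one plain attribute.
    def __init__(self):
        self.sampler = CounterSampler()
        self.dataset = CounterSampler()  # excluded below, as in the hunk above
        self.batch_size = 4  # not stateful, silently skipped


loader = FakeLoader()
sampler_state = {
    k: v.state_dict()
    for k, v in loader.__dict__.items()
    if isinstance(v, SupportsStateDict) and k != "dataset"
}
assert sampler_state == {"sampler": {"counter": 3}}
```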
@@ -19,6 +19,7 @@ from collections import defaultdict
 from collections.abc import Iterable
 from contextlib import suppress
 from copy import deepcopy
+from dataclasses import asdict
 from typing import List, Optional
 from unittest import mock
 from unittest.mock import ANY
@@ -42,6 +43,7 @@ from pytorch_lightning.utilities.auto_restart import (
     _dataloader_to_state_dict,
     _MultiProcessingDataLoaderIterStateful,
     _patch_dataloader_get_iterators,
+    _reload_dataloader_state_dict,
     _rotate_worker_indices,
     _SingleProcessDataLoaderIterStateful,
     _SupportsStateDict,
@@ -1289,7 +1291,7 @@ class StatefulRandomDataset(RandomDataset):
         return {"counter": self.counter}

     def load_state_dict(self, state_dict):
-        self.counter = state_dict["counter"]
+        self.counter = state_dict[0]["counter"]


 @pytest.mark.parametrize("num_workers", [0])
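The test change above follows from `_reload_dataloader_state_dict_manual`: the dataset no longer receives its raw state but a mapping keyed by (rotated) worker id. A sketch of the indexing, using only the shapes visible in this diff; the values are illustrative:

```python
# State as collected for a zero-worker run (layout from the manual-mode reload above).
state_dict = {
    "latest_worker_id": 0,
    "state": {0: {"num_workers": 0, "dataset_state": {0: {"counter": 1}}}},
}

# `_reload_dataloader_state_dict_manual` builds the per-worker mapping ...
dataset_state = {
    worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id]
    for worker_id in state_dict["state"]
}

# ... so the dataset's `load_state_dict` receives {0: {"counter": 1}},
# hence `state_dict[0]["counter"]` in `StatefulRandomDataset` above.
assert dataset_state == {0: {"counter": 1}}
```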
@@ -1319,7 +1321,9 @@ def test_stateful_workers(num_workers):
     assert isinstance(dataloader_iter, worker_type)

     next(data_fetcher_iter)
-    state = data_fetcher.dataloader_iter.state.state
+
+    reloaded_state = deepcopy(data_fetcher.dataloader_iter.state)
+    state = reloaded_state.state
     assert state[0].dataset_state == {0: {"counter": 1}}
     assert state[0].sampler_state["sampler"] == {"counter": 1}
@@ -1350,4 +1354,6 @@ def test_stateful_workers(num_workers):
     assert not hasattr(DataLoader, "_ori_get_iterator")
     assert DataLoader._get_iterator == _get_iterator_fn

+    _reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
+    assert dataloader.sampler.counter == dataloader.dataset.counter == 1
     data_fetcher.teardown()
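The closing assertions round-trip the snapshot: `data_fetcher.dataloader_iter.state` is a dataclass, and `_reload_dataloader_state_dict` consumes plain dictionaries, which is why the test converts it with `dataclasses.asdict`. A self-contained sketch with stand-in dataclasses; these stubs only approximate the real `MergedIteratorState` fields:

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict


@dataclass
class WorkerStateStub:
    # Stand-in for the per-worker iterator state; the real class has more fields.
    sampler_state: Dict[str, Any] = field(default_factory=dict)
    dataset_state: Dict[int, Any] = field(default_factory=dict)


@dataclass
class SnapshotStub:
    # Stand-in for MergedIteratorState.
    state: Dict[int, WorkerStateStub] = field(default_factory=dict)
    latest_worker_id: int = 0


snapshot = SnapshotStub(state={0: WorkerStateStub({"sampler": {"counter": 1}}, {0: {"counter": 1}})})

# `asdict` recurses through nested dataclasses, producing the plain-dict
# layout that `_reload_dataloader_state_dict` expects.
plain = asdict(snapshot)
assert plain["state"][0]["dataset_state"] == {0: {"counter": 1}}
```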