Fault Tolerant Manual: Add loading to reload the states (#10699)

Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
thomas chaton 2021-11-23 17:18:36 +00:00 committed by GitHub
parent dca1776870
commit b28ab34ff5
5 changed files with 89 additions and 33 deletions

View File

@@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add a `_rotate_worker_indices` utility to reload the state according to the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
* Add a utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
-

View File

@@ -22,7 +22,7 @@ from deprecate import void
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _update_dataloader_iter
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict
from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -182,7 +182,7 @@ class EvaluationEpochLoop(Loop):
def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
if not self.trainer.sanity_checking and self._dataloader_state_dict:
reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
_reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
self._dataloader_state_dict = None
def _num_completed_batches_reached(self) -> bool:
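For context, the loop consumes the saved dataloader state exactly once: outside of sanity checking it hands the stored state to the renamed `_reload_dataloader_state_dict` helper and then clears it, so later evaluation runs start from scratch. A minimal sketch of that reload-once pattern; the class and attribute names below are illustrative stand-ins, not the actual loop implementation:

from typing import Any, Dict, Optional

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict


class _ReloadOnce:
    """Illustrative stand-in for the loop's reload-once behaviour."""

    def __init__(self, dataloader_state_dict: Optional[Dict[str, Any]] = None) -> None:
        # Populated when restoring from a fault-tolerant checkpoint.
        self._dataloader_state_dict = dataloader_state_dict

    def maybe_reload(self, dataloader: DataLoader) -> None:
        if self._dataloader_state_dict:
            # Fast-forward the dataloader's samplers/dataset to where the previous run stopped.
            _reload_dataloader_state_dict(dataloader, self._dataloader_state_dict)
            # Consume the state so the next run starts from scratch.
            self._dataloader_state_dict = None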

View File

@@ -24,9 +24,9 @@ from torch.utils.data.dataset import IterableDataset
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_reload_dataloader_state_dict,
MergedIteratorState,
patch_dataloader_iterator,
reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -403,7 +403,7 @@ class CombinedLoader:
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader
reload_dataloader_state_dict(dataloader, state_dict)
_reload_dataloader_state_dict(dataloader, state_dict)
# We finally spawn the workers, if any.
it = iter(dataloader_to_iter_on)

View File

@@ -33,7 +33,6 @@ from typing_extensions import Protocol, runtime_checkable
import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
class FastForwardSampler(Sampler):
@@ -564,38 +563,90 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
)
def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
"""Utility to reload state_dict within dataloader for fault tolerance."""
def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
iterator_state = state_dict["state"][0]
if not _fault_tolerant_training():
return
if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)
# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
dataloader.dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)
def _reload_dataloader_state_dict_automatic_iterable_dataset(
dataset: CaptureIterableDataset, state_dict: Dict[str, Any]
) -> None:
dataset.load_state_dict(
{sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
)
def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
dataset = dataloader.dataset
if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]
if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)
# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)
_reload_dataloader_state_dict_automatic_map_dataset(dataloader, state_dict)
elif isinstance(dataset, CaptureIterableDataset):
dataset.load_state_dict(
{sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
)
_reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict)
else:
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")
def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
# In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`,
# so we need to reload the states manually.
latest_worker_id = state_dict["latest_worker_id"]
num_workers = state_dict["state"][latest_worker_id]["num_workers"]
sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
if sampler_state:
# `sampler_state` keys contain all the DataLoader attribute names
# which matched the `_SupportsStateDict` API interface when the `state_dict` was collected.
for dataloader_attr_name in sampler_state:
obj = getattr(dataloader, dataloader_attr_name)
if not isinstance(obj, _SupportsStateDict):
raise MisconfigurationException(
f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
)
obj.load_state_dict(sampler_state[dataloader_attr_name])
if not isinstance(dataloader.dataset, _SupportsStateDict):
return
dataset_state = {
worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id]
for worker_id in state_dict["state"].keys()
}
dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers))
def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
"""Utility to reload state_dict within dataloader for fault tolerance."""
fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
if not fault_tolerant_mode.is_enabled:
return
if fault_tolerant_mode.is_automatic:
_reload_dataloader_state_dict_automatic(dataloader, state_dict)
elif fault_tolerant_mode.is_manual:
_reload_dataloader_state_dict_manual(dataloader, state_dict)
else:
raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")
def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
@@ -638,7 +689,6 @@ class _StatefulDataLoaderIter:
for k, v in self._loader.__dict__.items()
if isinstance(v, _SupportsStateDict) and k != "dataset"
}
self.__accumulate_state(sampler_state)
def _next_index(self) -> Any:
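The manual path above only reloads objects that already expose `state_dict`/`load_state_dict`, i.e. that match the `_SupportsStateDict` protocol, so user objects stay unwrapped. Below is a minimal sketch of a sampler and dataset this mechanism could pick up. The class names are illustrative, and the `PL_FAULT_TOLERANT_TRAINING="manual"` setting is an assumption about how `_FaultTolerantMode.detect_current_mode()` selects manual mode:

import os
from typing import Any, Dict, Iterator

from torch.utils.data import DataLoader, Dataset, Sampler

# Assumption: manual fault tolerance is selected through this environment variable.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "manual"


class StatefulSampler(Sampler):
    """Illustrative sampler exposing the `_SupportsStateDict` interface."""

    def __init__(self, data_source) -> None:
        super().__init__(data_source)
        self.data_source = data_source
        self.counter = 0  # number of indices already produced

    def __iter__(self) -> Iterator[int]:
        for index in range(self.counter, len(self.data_source)):
            self.counter = index + 1
            yield index

    def __len__(self) -> int:
        return len(self.data_source)

    def state_dict(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Receives the flat per-attribute entry from `sampler_state`.
        self.counter = state_dict["counter"]


class StatefulDataset(Dataset):
    """Illustrative map-style dataset exposing the same interface."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.counter = 0

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int) -> int:
        self.counter += 1
        return index

    def state_dict(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # In manual mode the dataset receives a dict keyed by worker id (see the test below).
        self.counter = state_dict[0]["counter"]


dataset = StatefulDataset(64)
dataloader = DataLoader(dataset, sampler=StatefulSampler(dataset))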

View File

@@ -19,6 +19,7 @@ from collections import defaultdict
from collections.abc import Iterable
from contextlib import suppress
from copy import deepcopy
from dataclasses import asdict
from typing import List, Optional
from unittest import mock
from unittest.mock import ANY
@@ -42,6 +43,7 @@ from pytorch_lightning.utilities.auto_restart import (
_dataloader_to_state_dict,
_MultiProcessingDataLoaderIterStateful,
_patch_dataloader_get_iterators,
_reload_dataloader_state_dict,
_rotate_worker_indices,
_SingleProcessDataLoaderIterStateful,
_SupportsStateDict,
@@ -1289,7 +1291,7 @@ class StatefulRandomDataset(RandomDataset):
return {"counter": self.counter}
def load_state_dict(self, state_dict):
self.counter = state_dict["counter"]
self.counter = state_dict[0]["counter"]
@pytest.mark.parametrize("num_workers", [0])
@@ -1319,7 +1321,9 @@ def test_stateful_workers(num_workers):
assert isinstance(dataloader_iter, worker_type)
next(data_fetcher_iter)
state = data_fetcher.dataloader_iter.state.state
reloaded_state = deepcopy(data_fetcher.dataloader_iter.state)
state = reloaded_state.state
assert state[0].dataset_state == {0: {"counter": 1}}
assert state[0].sampler_state["sampler"] == {"counter": 1}
@@ -1350,4 +1354,6 @@
assert not hasattr(DataLoader, "_ori_get_iterator")
assert DataLoader._get_iterator == _get_iterator_fn
_reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
assert dataloader.sampler.counter == dataloader.dataset.counter == 1
data_fetcher.teardown()
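To make the asserts above easier to follow, the dictionary produced by `asdict(reloaded_state)` and handed to `_reload_dataloader_state_dict` looks roughly as follows in the `num_workers=0` case. Only the fields asserted in the test or read by the manual reload path are shown, so treat the exact shape as an approximation:

# Approximate shape of `asdict(reloaded_state)` after one batch with num_workers=0.
state_dict = {
    "latest_worker_id": 0,
    "state": {
        0: {  # keyed by worker id
            "num_workers": 0,  # as configured on the DataLoader
            # one entry per stateful DataLoader attribute (here, the sampler)
            "sampler_state": {"sampler": {"counter": 1}},
            # keyed again by worker id
            "dataset_state": {0: {"counter": 1}},
        }
    },
}

# The manual reload path then, in essence:
latest_worker_id = state_dict["latest_worker_id"]
num_workers = state_dict["state"][latest_worker_id]["num_workers"]
# 1. each stateful DataLoader attribute gets its own sub-dict:
#        dataloader.sampler.load_state_dict({"counter": 1})
# 2. the dataset gets the per-worker states, rotated by `_rotate_worker_indices`:
#        dataloader.dataset.load_state_dict({0: {"counter": 1}})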