Fault Tolerant Manual: Add _rotate_worker_indices utility (#10647)

This commit is contained in:
thomas chaton 2021-11-22 19:52:04 +00:00 committed by GitHub
parent 823bfa6f8a
commit 2036dfb5df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 8 deletions

View File

@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault Tolerant Manual
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
* Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
-

View File

@ -247,14 +247,7 @@ class CaptureMapDataset(Dataset):
def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
# as workers aren't available, the ``state_dict``` is cached until workers are made available.
state_dict = deepcopy(state_dict)
if num_workers > 0:
# remap states to worker ids starting at 0
next_worker_id = latest_worker_id + 1
old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
state_dict = {
new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict
}
state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers)
self._cached_state_dict = state_dict
def state_dict(self) -> Dict[int, Dict[str, Any]]:
@ -573,6 +566,20 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
"""This function is used to rotate the worker indices based on the `latest_worker_id` the training failed
on."""
if num_workers == 0:
return state
if latest_worker_id > num_workers - 1:
raise MisconfigurationException("The `latest_worker_id` should be within [0, num_workers - 1].")
if len(state) != num_workers:
raise MisconfigurationException("The `state` should contain `num_workers - 1` values.")
next_worker_id = latest_worker_id + 1
old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
@runtime_checkable
class _SupportsStateDict(Protocol):
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""

View File

@ -40,6 +40,7 @@ from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
_rotate_worker_indices,
_SupportsStateDict,
CaptureIterableDataset,
CaptureMapDataset,
@ -1196,6 +1197,19 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert "dataloader_state_dict" in state_dict
def test_rotate_worker_indices():
"""This test ensures `worker_id` are rotated properly depending on which one was the latest."""
state_dict = {0: 0, 1: 1}
assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0}
assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1}
with pytest.raises(MisconfigurationException, match="The `latest_worker_id` should be within"):
_rotate_worker_indices(state_dict, 2, 2)
with pytest.raises(MisconfigurationException, match="The `state` should contain"):
_rotate_worker_indices(state_dict, 2, 3)
def test_supports_state_dict_protocol():
class StatefulClass:
def state_dict(self):