Fault Tolerant Manual: Add _rotate_worker_indices utility (#10647)
This commit is contained in:
parent
823bfa6f8a
commit
2036dfb5df
|
@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fault Tolerant Manual
|
||||
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
|
||||
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
|
||||
* Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -247,14 +247,7 @@ class CaptureMapDataset(Dataset):
|
|||
def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
|
||||
# as workers aren't available, the ``state_dict``` is cached until workers are made available.
|
||||
state_dict = deepcopy(state_dict)
|
||||
|
||||
if num_workers > 0:
|
||||
# remap states to worker ids starting at 0
|
||||
next_worker_id = latest_worker_id + 1
|
||||
old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
|
||||
state_dict = {
|
||||
new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict
|
||||
}
|
||||
state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers)
|
||||
self._cached_state_dict = state_dict
|
||||
|
||||
def state_dict(self) -> Dict[int, Dict[str, Any]]:
|
||||
|
@ -573,6 +566,20 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A
|
|||
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
|
||||
|
||||
|
||||
def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
|
||||
"""This function is used to rotate the worker indices based on the `latest_worker_id` the training failed
|
||||
on."""
|
||||
if num_workers == 0:
|
||||
return state
|
||||
if latest_worker_id > num_workers - 1:
|
||||
raise MisconfigurationException("The `latest_worker_id` should be within [0, num_workers - 1].")
|
||||
if len(state) != num_workers:
|
||||
raise MisconfigurationException("The `state` should contain `num_workers - 1` values.")
|
||||
next_worker_id = latest_worker_id + 1
|
||||
old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
|
||||
return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class _SupportsStateDict(Protocol):
|
||||
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
|
||||
|
|
|
@ -40,6 +40,7 @@ from pytorch_lightning.utilities.auto_restart import (
|
|||
_add_capture_metadata_collate,
|
||||
_dataloader_load_state_dict,
|
||||
_dataloader_to_state_dict,
|
||||
_rotate_worker_indices,
|
||||
_SupportsStateDict,
|
||||
CaptureIterableDataset,
|
||||
CaptureMapDataset,
|
||||
|
@ -1196,6 +1197,19 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
|
|||
assert "dataloader_state_dict" in state_dict
|
||||
|
||||
|
||||
def test_rotate_worker_indices():
|
||||
"""This test ensures `worker_id` are rotated properly depending on which one was the latest."""
|
||||
state_dict = {0: 0, 1: 1}
|
||||
assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0}
|
||||
assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1}
|
||||
|
||||
with pytest.raises(MisconfigurationException, match="The `latest_worker_id` should be within"):
|
||||
_rotate_worker_indices(state_dict, 2, 2)
|
||||
|
||||
with pytest.raises(MisconfigurationException, match="The `state` should contain"):
|
||||
_rotate_worker_indices(state_dict, 2, 3)
|
||||
|
||||
|
||||
def test_supports_state_dict_protocol():
|
||||
class StatefulClass:
|
||||
def state_dict(self):
|
||||
|
|
Loading…
Reference in New Issue