diff --git a/CHANGELOG.md b/CHANGELOG.md index a454739c4e..adb1b070dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant Manual * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) + * Add a `_rotate_worker_indices` utility to reload the state according to the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) - diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 23583852f4..4cb1793643 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -247,14 +247,7 @@ class CaptureMapDataset(Dataset): def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None: # as workers aren't available, the ``state_dict``` is cached until workers are made available. state_dict = deepcopy(state_dict) - - if num_workers > 0: - # remap states to worker ids starting at 0 - next_worker_id = latest_worker_id + 1 - old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] - state_dict = { - new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict - } + state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers) self._cached_state_dict = state_dict def state_dict(self) -> Dict[int, Dict[str, Any]]: @@ -573,6 +566,20 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.") +def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]: + """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed + on.""" + if num_workers == 0: + return state + if latest_worker_id > num_workers - 1: + raise MisconfigurationException("The `latest_worker_id` should be within [0, num_workers - 1].") + if len(state) != num_workers: + raise MisconfigurationException("The `state` should contain `num_workers - 1` values.") + next_worker_id = latest_worker_id + 1 + old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] + return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state} + + @runtime_checkable class _SupportsStateDict(Protocol): """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 5152874b39..d9063f90db 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,6 +40,7 @@ from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _rotate_worker_indices, _SupportsStateDict, CaptureIterableDataset, CaptureMapDataset, @@ -1196,6 +1197,19 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict +def test_rotate_worker_indices(): + """This test ensures `worker_id` are rotated properly depending on which one was the latest.""" + state_dict = {0: 0, 1: 1} + assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0} + assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1} + + with pytest.raises(MisconfigurationException, match="The `latest_worker_id` should be within"): + 
_rotate_worker_indices(state_dict, 2, 2) + + with pytest.raises(MisconfigurationException, match="The `state` should contain"): + _rotate_worker_indices(state_dict, 2, 3) + + def test_supports_state_dict_protocol(): class StatefulClass: def state_dict(self):