From 3d6262b7a91215e72019d720e742e6261e1636dc Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Thu, 25 Nov 2021 17:31:53 +0000
Subject: [PATCH] Fault Tolerant Manual: Add support for DDP (#10638)

---
 CHANGELOG.md                                 |  3 +++
 .../loops/epoch/evaluation_epoch_loop.py     | 15 +++++++++++---
 .../loops/epoch/training_epoch_loop.py       |  6 ++++--
 .../trainer/connectors/data_connector.py     |  2 ++
 pytorch_lightning/trainer/supporters.py      |  5 +++++
 pytorch_lightning/trainer/trainer.py         | 15 ++++++++++----
 pytorch_lightning/utilities/auto_restart.py  | 13 ++++++++++++
 pytorch_lightning/utilities/distributed.py   | 20 +++++++++----------
 pytorch_lightning/utilities/imports.py       |  5 +++--
 tests/utilities/test_auto_restart.py         |  8 ++++++++
 tests/utilities/test_distributed.py          |  9 ++++-----
 11 files changed, 74 insertions(+), 27 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 43c0c6ab14..3a136fe023 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+
 - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601))
 
 
@@ -21,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
   * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
   * Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))
+  * Broadcast the `_terminate_gracefully` to all processes and add support for DDP ([#10638](https://github.com/PyTorchLightning/pytorch-lightning/issues/10638))
+
 
 - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
 
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 2fc572ea25..b7bfc1e0ed 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -22,8 +22,13 @@ from deprecate import void
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.utilities import _update_dataloader_iter
 from pytorch_lightning.trainer.progress import BatchProgress
-from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
+from pytorch_lightning.utilities.auto_restart import (
+    _collect_states_on_rank_zero_over_collection,
+    _reload_dataloader_state_dict,
+    MergedIteratorState,
+)
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
 
@@ -173,12 +178,16 @@ class EvaluationEpochLoop(Loop):
         state_to_save = "state" if self._has_completed() else "previous_state"
         state: Optional[MergedIteratorState] = getattr(self._data_fetcher.dataloader_iter, state_to_save, None)
         if state:
-            state_dict["dataloader_state_dict"] = asdict(state)
+            state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(asdict(state))
         return state_dict
 
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         # cache the dataloader state dict until the dataloader objects are available
-        self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
+        # dataset states are collected across all ranks
+        dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
+        if not _fault_tolerant_training() or not dataloader_state_dict:
+            return
+        self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank]
 
     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
         if not self.trainer.sanity_checking and self._dataloader_state_dict:
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 8ddca3ad50..a75ad470c2 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -25,6 +25,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -320,8 +321,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
             or self.batch_progress.current.ready == 0  # did not start
         ):
             return state_dict
-        state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(
-            has_completed=self._has_completed()
+
+        state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(
+            self.trainer.train_dataloader.state_dict(has_completed=self._has_completed())
         )
         return state_dict
 
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index deee64c90f..e6f76e0403 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -19,6 +19,7 @@ from weakref import proxy
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities.auto_restart import _teardown_dataloader_get_iterators
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import (
     AbstractDataFetcher,
@@ -254,6 +255,7 @@ class DataConnector:
         if self.sanity_check_data_fetcher:
             self.sanity_check_data_fetcher.teardown()
             self.sanity_check_data_fetcher = None
+        _teardown_dataloader_get_iterators()
 
 
 @dataclass
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index d65bc08e66..df86ea157f 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -29,6 +29,7 @@ from pytorch_lightning.utilities.auto_restart import (
     patch_dataloader_iterator,
 )
 from pytorch_lightning.utilities.data import get_len
+from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
@@ -403,6 +404,10 @@ class CombinedLoader:
             if isinstance(dataloader, CycleIterator):
                 dataloader = dataloader_to_iter_on.loader
 
+            # dataset states are collected across all ranks
+            rank = torch.distributed.get_rank() if distributed_available() else 0
+            state_dict = state_dict[rank]
+
             _reload_dataloader_state_dict(dataloader, state_dict)
 
             # We finally spawned the workers if any.
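The `CombinedLoader` change above relies on the checkpointed dataloader state now being keyed by rank. A minimal single-process sketch of that layout, once this patch is applied (the `saved` dict is hypothetical; with no process group initialized, `distributed_available()` is false and everything lands under rank 0):

    from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection

    # Any dict carrying a "state" key is wrapped into a {rank: state} mapping.
    saved = {"train_loader": {"state": {"current_iteration": 12}}}
    collected = _collect_states_on_rank_zero_over_collection(saved)
    assert collected == {"train_loader": {0: {"state": {"current_iteration": 12}}}}

    # On reload, each process indexes back in with its own rank (0 here),
    # recovering exactly the entry it saved.
    rank = 0
    assert collected["train_loader"][rank] == saved["train_loader"]
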
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 18f13a75bf..73e9437040 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -2100,10 +2100,17 @@ class Trainer(
         return active_loop._results
 
     def _exit_gracefully_on_signal(self) -> None:
-        if _fault_tolerant_training() and self._terminate_gracefully:
-            caller = inspect.stack()[1]
-            class_name = caller[0].f_locals["self"].__class__.__name__
-            raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")
+        if not _fault_tolerant_training():
+            return
+        if not self._should_terminated_gracefully():
+            return
+        caller = inspect.stack()[1]
+        class_name = caller[0].f_locals["self"].__class__.__name__
+        raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")
+
+    def _should_terminated_gracefully(self) -> bool:
+        value = torch.tensor(self._terminate_gracefully, device=self.training_type_plugin.root_device)
+        return self.training_type_plugin.reduce(value, reduce_op="sum") > 0
 
     @property
     def weights_summary(self) -> Optional[str]:
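`_should_terminated_gracefully` turns the per-process `_terminate_gracefully` flag into a collective decision: summing the flags across ranks acts as a logical OR, so a signal caught on any single rank exits all of them together. A minimal sketch of the same pattern, assuming a two-process `gloo` group on CPU (the patched method goes through `training_type_plugin.reduce` on the plugin's root device instead):

    import os

    import torch
    import torch.distributed
    import torch.multiprocessing as mp


    def _worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

        terminate_locally = rank == 1  # pretend only rank 1 received a signal
        flag = torch.tensor(int(terminate_locally))
        torch.distributed.all_reduce(flag, op=torch.distributed.ReduceOp.SUM)
        assert flag.item() > 0  # every rank now agrees to exit
        torch.distributed.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(_worker, args=(2,), nprocs=2)
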
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 9f99634bd1..84f0c9dece 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -31,6 +31,8 @@ from torch.utils.data.dataloader import (
 from typing_extensions import Protocol, runtime_checkable
 
 import pytorch_lightning as pl
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
 from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -737,3 +739,14 @@ def _teardown_dataloader_get_iterators() -> None:
     if get_iterator:
         DataLoader._get_iterator = get_iterator
         del DataLoader._ori_get_iterator
+
+
+def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
+    """This utility collects the state across processes for a collection of state."""
+
+    def fn(state: Dict):
+        if key in state:
+            return _collect_states_on_rank_zero(state)
+        return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
+
+    return apply_to_collection(state_dict, Dict, fn)
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 7c6e4f4048..5612752569 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -378,7 +378,14 @@ def init_dist_connection(
     )
 
 
-def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]:
+def _broadcast_object_list(obj: Any, rank: int) -> Any:
+    objects = [obj if torch.distributed.get_rank() == rank else None]
+    torch.distributed.broadcast_object_list(objects, src=rank)
+    return objects[0]
+
+
+# TODO: Refactor with the Strategy Collectives once finalized.
+def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
     """This distributed utility collects dictionary state across all processes.
 
     Args:
@@ -391,13 +398,4 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]:
     """
     if not distributed_available():
         return {0: state}
-    states = {}
-    current_rank = torch.distributed.get_rank()
-    for rank in range(1, torch.distributed.get_world_size()):
-        objects = [state if current_rank == rank else None]
-        torch.distributed.broadcast_object_list(objects, src=rank, device=device)
-        states[rank] = objects[0]
-    if current_rank != 0:
-        return None
-    states[0] = state
-    return states
+    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}
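In the rewritten `_collect_states_on_rank_zero`, every rank takes a turn as the `broadcast_object_list` source, so all ranks (not only rank zero, despite the name) finish with the full `{rank: state}` mapping. That is what lets `on_load_checkpoint` above index by `global_rank` on every process. A sketch of the round-robin gather, assuming a two-process `gloo` group on CPU:

    import os

    import torch.distributed
    import torch.multiprocessing as mp


    def _worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

        state = {"counter": rank * 10}
        states = {}
        for src in range(world_size):
            # Only the source rank contributes its state; the others pass a placeholder.
            objects = [state if rank == src else None]
            torch.distributed.broadcast_object_list(objects, src=src)
            states[src] = objects[0]
        # Every rank ends up with the same complete mapping.
        assert states == {0: {"counter": 0}, 1: {"counter": 10}}
        torch.distributed.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(_worker, args=(2,), nprocs=2)
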
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index aa6349b5d6..49c94d87e6 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -14,7 +14,6 @@
 """General utilities."""
 import importlib
 import operator
-import os
 import platform
 import sys
 from importlib.util import find_spec
@@ -111,4 +110,6 @@ else:
 # experimental feature within PyTorch Lightning.
 def _fault_tolerant_training() -> bool:
-    return bool(int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0)))
+    from pytorch_lightning.utilities.enums import _FaultTolerantMode
+
+    return _FaultTolerantMode.detect_current_mode().is_enabled
 
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index c69b70b65b..58a11d0de6 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -39,6 +39,7 @@ from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities.auto_restart import (
     _add_capture_metadata_collate,
+    _collect_states_on_rank_zero_over_collection,
     _MultiProcessingDataLoaderIterStateful,
     _patch_dataloader_get_iterators,
     _reload_dataloader_state_dict,
@@ -1254,6 +1255,13 @@ class StatefulRandomDataset(RandomDataset):
         self.counter = state_dict[0]["counter"]
 
 
+def test_collect_states_with_collection():
+    state = {"state": 0}
+    collection = [{"a": state, "b": [{"a": state}]}]
+    generated = _collect_states_on_rank_zero_over_collection(collection)
+    assert generated == [{"a": {0: state}, "b": [{"a": {0: state}}]}]
+
+
 @pytest.mark.parametrize("num_workers", [0])
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
 def test_stateful_workers(num_workers):
diff --git a/tests/utilities/test_distributed.py b/tests/utilities/test_distributed.py
index a48b4486a4..6226aadecb 100644
--- a/tests/utilities/test_distributed.py
+++ b/tests/utilities/test_distributed.py
@@ -64,15 +64,14 @@ def test_rank_zero_none_set(rank_key, rank):
 
 
 def _test_collect_states(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
+    torch.cuda.set_device(f"cuda:{rank}")
+
     # initialize the process group
     torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
     state = {"something": torch.tensor([rank])}
-    collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}"))
-    if rank == 0:
-        assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
-    else:
-        assert collected_state is None
+    collected_state = _collect_states_on_rank_zero(state)
+    assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
 
 
 @RunIf(skip_windows=True, min_gpus=2, min_torch="1.10")
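For reference, `_fault_tolerant_training()` now resolves through the `PL_FAULT_TOLERANT_TRAINING` environment variable read by `_FaultTolerantMode.detect_current_mode()`. A sketch, assuming the enum accepts "2" for manual mode as the `test_stateful_workers` decorator above suggests (the `MANUAL` member name is an assumption from the linked PRs):

    import os

    # "1" would select automatic mode; unset or "0" disables fault tolerance.
    os.environ["PL_FAULT_TOLERANT_TRAINING"] = "2"

    from pytorch_lightning.utilities.enums import _FaultTolerantMode
    from pytorch_lightning.utilities.imports import _fault_tolerant_training

    assert _FaultTolerantMode.detect_current_mode() is _FaultTolerantMode.MANUAL
    assert _fault_tolerant_training()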