Fault Tolerant Manual: Add support for DDP (#10638)

thomas chaton 2021-11-25 17:31:53 +00:00 committed by GitHub
parent e0b4bb2ea3
commit 3d6262b7a9
11 changed files with 74 additions and 27 deletions

View File

@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601))
@@ -21,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
* Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
* Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))
* Broadcast the `_terminate_gracefully` to all processes and add support for DDP ([#10638](https://github.com/PyTorchLightning/pytorch-lightning/issues/10638))
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10680))

View File

@@ -22,8 +22,13 @@ from deprecate import void
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _update_dataloader_iter
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
from pytorch_lightning.utilities.auto_restart import (
_collect_states_on_rank_zero_over_collection,
_reload_dataloader_state_dict,
MergedIteratorState,
)
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -173,12 +178,16 @@ class EvaluationEpochLoop(Loop):
state_to_save = "state" if self._has_completed() else "previous_state"
state: Optional[MergedIteratorState] = getattr(self._data_fetcher.dataloader_iter, state_to_save, None)
if state:
state_dict["dataloader_state_dict"] = asdict(state)
state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(asdict(state))
return state_dict
def on_load_checkpoint(self, state_dict: Dict) -> None:
# cache the dataloader state dict until the dataloader objects are available
self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
# dataset states are collected across all ranks
dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
if not _fault_tolerant_training() or not dataloader_state_dict:
return
self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank]
def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
if not self.trainer.sanity_checking and self._dataloader_state_dict:
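The change above follows a per-rank save/restore pattern: on save, every process contributes its dataloader state and the result is a mapping keyed by global rank; on load, each process takes back only its own slice. A minimal sketch of that shape for a hypothetical 2-process DDP run (field names are illustrative, not Lightning's actual schema):

```python
# Hypothetical payload stored under "dataloader_state_dict" after
# `_collect_states_on_rank_zero_over_collection` has run on 2 ranks.
collected = {
    0: {"sampler_iter": 12, "dataset_counter": 24},
    1: {"sampler_iter": 12, "dataset_counter": 25},
}

# On restore, each DDP process keeps only the entry matching its own rank,
# mirroring `dataloader_state_dict[self.trainer.global_rank]` above.
global_rank = 1
my_state = collected[global_rank]  # {"sampler_iter": 12, "dataset_counter": 25}
```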

View File

@@ -25,6 +25,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultC
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -320,8 +321,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
or self.batch_progress.current.ready == 0 # did not start
):
return state_dict
state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(
has_completed=self._has_completed()
state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(
self.trainer.train_dataloader.state_dict(has_completed=self._has_completed())
)
return state_dict

View File

@@ -19,6 +19,7 @@ from weakref import proxy
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.auto_restart import _teardown_dataloader_get_iterators
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
@@ -254,6 +255,7 @@ class DataConnector:
if self.sanity_check_data_fetcher:
self.sanity_check_data_fetcher.teardown()
self.sanity_check_data_fetcher = None
_teardown_dataloader_get_iterators()
@dataclass
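`_teardown_dataloader_get_iterators` (defined further down in this commit) restores the `DataLoader._get_iterator` method that fault-tolerant training monkeypatches. A standalone sketch of the same patch/restore pattern, with hypothetical helper names:

```python
from torch.utils.data import DataLoader

def patch_get_iterator(replacement):
    # Keep a handle to the original implementation before swapping it out,
    # mirroring the `_ori_get_iterator` backup used in the diff below.
    if not hasattr(DataLoader, "_ori_get_iterator"):
        DataLoader._ori_get_iterator = DataLoader._get_iterator
    DataLoader._get_iterator = replacement

def restore_get_iterator():
    # Counterpart of `_teardown_dataloader_get_iterators`: put the original back
    # and drop the backup so repeated teardowns are harmless.
    original = getattr(DataLoader, "_ori_get_iterator", None)
    if original is not None:
        DataLoader._get_iterator = original
        del DataLoader._ori_get_iterator
```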

View File

@@ -29,6 +29,7 @@ from pytorch_lightning.utilities.auto_restart import (
patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
@@ -403,6 +404,10 @@ class CombinedLoader:
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader
# dataset states are collected across all ranks
rank = torch.distributed.get_rank() if distributed_available() else 0
state_dict = state_dict[rank]
_reload_dataloader_state_dict(dataloader, state_dict)
# Finally, spawn the workers, if any.

View File

@@ -2100,10 +2100,17 @@ class Trainer(
return active_loop._results
def _exit_gracefully_on_signal(self) -> None:
if _fault_tolerant_training() and self._terminate_gracefully:
caller = inspect.stack()[1]
class_name = caller[0].f_locals["self"].__class__.__name__
raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")
if not _fault_tolerant_training():
return
if not self._should_terminated_gracefully():
return
caller = inspect.stack()[1]
class_name = caller[0].f_locals["self"].__class__.__name__
raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")
def _should_terminated_gracefully(self) -> bool:
value = torch.tensor(self._terminate_gracefully, device=self.training_type_plugin.root_device)
return self.training_type_plugin.reduce(value, reduce_op="sum") > 0
@property
def weights_summary(self) -> Optional[str]:
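`_should_terminated_gracefully` turns a per-process flag into a global decision by summing it across ranks, so a signal received on any single rank makes every rank raise `ExitGracefullyException`. A standalone sketch of the same reduction using plain `torch.distributed` (the commit itself goes through `self.training_type_plugin.reduce`):

```python
import torch
import torch.distributed as dist

def should_terminate_gracefully(local_flag: bool, device: torch.device) -> bool:
    # Sum the flag over all processes: any rank that saw the signal pushes the
    # total above zero, and every rank then agrees to exit gracefully.
    value = torch.tensor(int(local_flag), device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return bool(value.item() > 0)
```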

View File

@@ -31,6 +31,8 @@ from torch.utils.data.dataloader import (
from typing_extensions import Protocol, runtime_checkable
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -737,3 +739,14 @@ def _teardown_dataloader_get_iterators() -> None:
if get_iterator:
DataLoader._get_iterator = get_iterator
del DataLoader._ori_get_iterator
def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
"""This utility collects the state across processes for a collection of state."""
def fn(state: Dict):
if key in state:
return _collect_states_on_rank_zero(state)
return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
return apply_to_collection(state_dict, Dict, fn)
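The helper above leans on `apply_to_collection` to walk arbitrarily nested lists and dicts, treating any dict that carries the `state` key as a leaf to be gathered. A small sketch of the same recursion with the cross-rank gather stubbed out (the stub stands in for `_collect_states_on_rank_zero`):

```python
from typing import Any, Dict
from pytorch_lightning.utilities.apply_func import apply_to_collection

def wrap_leaf_states(collection: Any, key: str = "state") -> Any:
    # Dicts containing `key` are leaves; everything else is traversed recursively,
    # exactly like `_collect_states_on_rank_zero_over_collection` above.
    def fn(state: Dict) -> Dict:
        if key in state:
            return {0: state}  # stand-in for the real cross-rank gather
        return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
    return apply_to_collection(collection, Dict, fn)

# wrap_leaf_states([{"a": {"state": 0}}]) -> [{"a": {0: {"state": 0}}}]
```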

View File

@@ -378,7 +378,14 @@ def init_dist_connection(
)
def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]:
def _broadcast_object_list(obj: Any, rank: int) -> Any:
objects = [obj if torch.distributed.get_rank() == rank else None]
torch.distributed.broadcast_object_list(objects, src=rank)
return objects[0]
# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
"""This distributed utility collects dictionary state across all processes.
Args:
@@ -391,13 +398,4 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) ->
"""
if not distributed_available():
return {0: state}
states = {}
current_rank = torch.distributed.get_rank()
for rank in range(1, torch.distributed.get_world_size()):
objects = [state if current_rank == rank else None]
torch.distributed.broadcast_object_list(objects, src=rank, device=device)
states[rank] = objects[0]
if current_rank != 0:
return None
states[0] = state
return states
return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}
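The rewritten helper relies on `torch.distributed.broadcast_object_list`: each rank takes a turn broadcasting its own state, so unlike the previous version every process, not just rank 0, ends up with the full rank-to-state mapping. That is what lets the loops index the checkpoint entry by `global_rank` on reload. A standalone sketch of the same gather, assuming the default process group is already initialized:

```python
import torch.distributed as dist

def gather_states_on_all_ranks(state):
    # Without an initialized process group there is only one "rank".
    if not (dist.is_available() and dist.is_initialized()):
        return {0: state}
    states = {}
    for rank in range(dist.get_world_size()):
        # Only the owning rank puts a real object in the list; every other
        # process passes a placeholder that broadcast_object_list fills in.
        objects = [state if dist.get_rank() == rank else None]
        dist.broadcast_object_list(objects, src=rank)
        states[rank] = objects[0]
    return states
```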

View File

@@ -14,7 +14,6 @@
"""General utilities."""
import importlib
import operator
import os
import platform
import sys
from importlib.util import find_spec
@@ -111,4 +110,6 @@ else:
# experimental feature within PyTorch Lightning.
def _fault_tolerant_training() -> bool:
return bool(int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0)))
from pytorch_lightning.utilities.enums import _FaultTolerantMode
return _FaultTolerantMode.detect_current_mode().is_enabled
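`_fault_tolerant_training` now defers to `_FaultTolerantMode`, which is why the test further down requests manual mode with `PL_FAULT_TOLERANT_TRAINING=2`. A rough, non-authoritative sketch of how such an enum can interpret the environment variable (the exact accepted values are an assumption, not the Lightning implementation):

```python
import os
from enum import Enum

class FaultTolerantModeSketch(Enum):
    # Approximation of `_FaultTolerantMode`; the value mapping below is assumed.
    DISABLED = "disabled"
    AUTOMATIC = "automatic"
    MANUAL = "manual"

    @property
    def is_enabled(self) -> bool:
        return self is not FaultTolerantModeSketch.DISABLED

    @classmethod
    def detect_current_mode(cls) -> "FaultTolerantModeSketch":
        env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower()
        if env_value in ("0", "false", "disabled"):
            return cls.DISABLED
        if env_value in ("1", "true", "automatic"):
            return cls.AUTOMATIC
        if env_value in ("2", "manual"):
            return cls.MANUAL
        raise ValueError(f"Unsupported PL_FAULT_TOLERANT_TRAINING value: {env_value!r}")
```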

View File

@@ -39,6 +39,7 @@ from pytorch_lightning import Callback, LightningModule, seed_everything, Traine
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_collect_states_on_rank_zero_over_collection,
_MultiProcessingDataLoaderIterStateful,
_patch_dataloader_get_iterators,
_reload_dataloader_state_dict,
@@ -1254,6 +1255,13 @@ class StatefulRandomDataset(RandomDataset):
self.counter = state_dict[0]["counter"]
def test_collect_states_with_collection():
state = {"state": 0}
collection = [{"a": state, "b": [{"a": state}]}]
generated = _collect_states_on_rank_zero_over_collection(collection)
assert generated == [{"a": {0: state}, "b": [{"a": {0: state}}]}]
@pytest.mark.parametrize("num_workers", [0])
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
def test_stateful_workers(num_workers):

View File

@@ -64,15 +64,14 @@ def test_rank_zero_none_set(rank_key, rank):
def _test_collect_states(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
torch.cuda.set_device(f"cuda:{rank}")
# initialize the process group
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
state = {"something": torch.tensor([rank])}
collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}"))
if rank == 0:
assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
else:
assert collected_state is None
collected_state = _collect_states_on_rank_zero(state)
assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
@RunIf(skip_windows=True, min_gpus=2, min_torch="1.10")