Fault Tolerant Manual: Add support for DDP (#10638)
parent e0b4bb2ea3
commit 3d6262b7a9
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601))

@@ -21,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

    * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
    * Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
    * Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))
+   * Broadcast the `_terminate_gracefully` to all processes and add support for DDP ([#10638](https://github.com/PyTorchLightning/pytorch-lightning/issues/10638))

- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))

@@ -22,8 +22,13 @@ from deprecate import void

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _update_dataloader_iter
from pytorch_lightning.trainer.progress import BatchProgress
-from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
+from pytorch_lightning.utilities.auto_restart import (
+    _collect_states_on_rank_zero_over_collection,
+    _reload_dataloader_state_dict,
+    MergedIteratorState,
+)
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
+from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT

@@ -173,12 +178,16 @@ class EvaluationEpochLoop(Loop):

        state_to_save = "state" if self._has_completed() else "previous_state"
        state: Optional[MergedIteratorState] = getattr(self._data_fetcher.dataloader_iter, state_to_save, None)
        if state:
-            state_dict["dataloader_state_dict"] = asdict(state)
+            state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(asdict(state))
        return state_dict

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # cache the dataloader state dict until the dataloader objects are available
-        self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
+        # dataset states are collected across all ranks
+        dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
+        if not _fault_tolerant_training() or not dataloader_state_dict:
+            return
+        self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank]

    def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
        if not self.trainer.sanity_checking and self._dataloader_state_dict:

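In short, the checkpoint entry changes shape here: on save, the dataloader state is wrapped into a per-rank mapping collected across processes, and on load each process restores only the entry for its own global rank. A minimal illustration with hypothetical state contents:

# hypothetical checkpoint fragment produced by a 2-process run
state_dict = {"dataloader_state_dict": {0: {"counter": 10}, 1: {"counter": 12}}}

global_rank = 1  # e.g. the second DDP process
restored = state_dict["dataloader_state_dict"][global_rank]
assert restored == {"counter": 12}
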
@@ -25,6 +25,7 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import ResultC

from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -320,8 +321,9 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):

            or self.batch_progress.current.ready == 0  # did not start
        ):
            return state_dict
-        state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(
-            has_completed=self._has_completed()
+        state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(
+            self.trainer.train_dataloader.state_dict(has_completed=self._has_completed())
+        )
        return state_dict

@@ -19,6 +19,7 @@ from weakref import proxy

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities.auto_restart import _teardown_dataloader_get_iterators
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
    AbstractDataFetcher,

@@ -254,6 +255,7 @@ class DataConnector:

        if self.sanity_check_data_fetcher:
            self.sanity_check_data_fetcher.teardown()
            self.sanity_check_data_fetcher = None
+        _teardown_dataloader_get_iterators()


@dataclass

@@ -29,6 +29,7 @@ from pytorch_lightning.utilities.auto_restart import (

    patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
+from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

@@ -403,6 +404,10 @@ class CombinedLoader:

            if isinstance(dataloader, CycleIterator):
                dataloader = dataloader_to_iter_on.loader

+            # dataset states are collected across all ranks
+            rank = torch.distributed.get_rank() if distributed_available() else 0
+            state_dict = state_dict[rank]
+
            _reload_dataloader_state_dict(dataloader, state_dict)

            # We finally spawned the workers if any.

@@ -2100,10 +2100,17 @@ class Trainer(

            return active_loop._results

    def _exit_gracefully_on_signal(self) -> None:
-        if _fault_tolerant_training() and self._terminate_gracefully:
-            caller = inspect.stack()[1]
-            class_name = caller[0].f_locals["self"].__class__.__name__
-            raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")
+        if not _fault_tolerant_training():
+            return
+        if not self._should_terminated_gracefully():
+            return
+        caller = inspect.stack()[1]
+        class_name = caller[0].f_locals["self"].__class__.__name__
+        raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")
+
+    def _should_terminated_gracefully(self) -> bool:
+        value = torch.tensor(self._terminate_gracefully, device=self.training_type_plugin.root_device)
+        return self.training_type_plugin.reduce(value, reduce_op="sum") > 0

    @property
    def weights_summary(self) -> Optional[str]:

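For orientation, a standalone sketch of the new check, assuming `training_type_plugin.reduce(value, reduce_op="sum")` behaves like an all-reduce sum under DDP: if any rank has caught a signal and set `_terminate_gracefully`, the summed flag is positive on every rank, so all processes raise `ExitGracefullyException` at the same point instead of hanging.

import torch
import torch.distributed as dist


def should_terminate(local_flag: bool, device: torch.device) -> bool:
    # hypothetical standalone equivalent of `_should_terminated_gracefully`:
    # sum the per-rank flags so a signal caught by any single rank is seen by all
    value = torch.tensor(int(local_flag), device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return bool(value.item() > 0)
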
@@ -31,6 +31,8 @@ from torch.utils.data.dataloader import (

from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -737,3 +739,14 @@ def _teardown_dataloader_get_iterators() -> None:

    if get_iterator:
        DataLoader._get_iterator = get_iterator
        del DataLoader._ori_get_iterator
+
+
+def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
+    """This utility collects the state across processes for a collection of state."""
+
+    def fn(state: Dict):
+        if key in state:
+            return _collect_states_on_rank_zero(state)
+        return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
+
+    return apply_to_collection(state_dict, Dict, fn)

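As a single-process illustration (mirroring the unit test added further below), every nested dict carrying the "state" key is replaced by a rank-indexed mapping while the surrounding collection structure is preserved:

from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection

state = {"state": 0}
collection = [{"a": state, "b": [{"a": state}]}]
# without an initialised process group, the states are keyed by rank 0 only
assert _collect_states_on_rank_zero_over_collection(collection) == [
    {"a": {0: state}, "b": [{"a": {0: state}}]}
]
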
@@ -378,7 +378,14 @@ def init_dist_connection(

    )


-def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) -> Optional[Dict[int, Any]]:
+def _broadcast_object_list(obj: Any, rank: int) -> Any:
+    objects = [obj if torch.distributed.get_rank() == rank else None]
+    torch.distributed.broadcast_object_list(objects, src=rank)
+    return objects[0]
+
+
+# TODO: Refactor with the Strategy Collectives once finalized.
+def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    """This distributed utility collects dictionary state across all processes.

    Args:

@@ -391,13 +398,4 @@ def _collect_states_on_rank_zero(state: Dict[str, Any], device: torch.device) ->

    """
    if not distributed_available():
        return {0: state}
-    states = {}
-    current_rank = torch.distributed.get_rank()
-    for rank in range(1, torch.distributed.get_world_size()):
-        objects = [state if current_rank == rank else None]
-        torch.distributed.broadcast_object_list(objects, src=rank, device=device)
-        states[rank] = objects[0]
-    if current_rank != 0:
-        return None
-    states[0] = state
-    return states
+    return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}

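The rewritten helper lets every rank broadcast its own state in turn and returns the full `{rank: state}` mapping on every process, whereas the old version returned it on rank 0 only and `None` elsewhere. A rough usage sketch, assuming a two-process group has already been initialised as in the updated test below:

import torch
import torch.distributed as dist
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero

# assumes dist.init_process_group(...) was called with world_size == 2
rank = dist.get_rank()
collected = _collect_states_on_rank_zero({"something": torch.tensor([rank])})
# every rank now sees {0: {"something": tensor([0])}, 1: {"something": tensor([1])}}
assert set(collected) == {0, 1}
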
@@ -14,7 +14,6 @@

"""General utilities."""
import importlib
import operator
import os
import platform
import sys
from importlib.util import find_spec

@@ -111,4 +110,6 @@ else:

# experimental feature within PyTorch Lightning.
def _fault_tolerant_training() -> bool:
-    return bool(int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0)))
+    from pytorch_lightning.utilities.enums import _FaultTolerantMode
+
+    return _FaultTolerantMode.detect_current_mode().is_enabled

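For context, a small usage sketch of the new behaviour, assuming the mode enum treats any non-disabled value of `PL_FAULT_TOLERANT_TRAINING` as enabled (the test below uses "2" for the manual mode):

import os

# hypothetical usage: the helper is now driven by _FaultTolerantMode rather
# than a plain 0/1 integer parse of the environment variable
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "2"

from pytorch_lightning.utilities.imports import _fault_tolerant_training

assert _fault_tolerant_training()
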
@@ -39,6 +39,7 @@ from pytorch_lightning import Callback, LightningModule, seed_everything, Traine

from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
+    _collect_states_on_rank_zero_over_collection,
    _MultiProcessingDataLoaderIterStateful,
    _patch_dataloader_get_iterators,
    _reload_dataloader_state_dict,

@@ -1254,6 +1255,13 @@ class StatefulRandomDataset(RandomDataset):

        self.counter = state_dict[0]["counter"]


+def test_collect_states_with_collection():
+    state = {"state": 0}
+    collection = [{"a": state, "b": [{"a": state}]}]
+    generated = _collect_states_on_rank_zero_over_collection(collection)
+    assert generated == [{"a": {0: state}, "b": [{"a": {0: state}}]}]
+
+
@pytest.mark.parametrize("num_workers", [0])
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
def test_stateful_workers(num_workers):

@@ -64,15 +64,14 @@ def test_rank_zero_none_set(rank_key, rank):

def _test_collect_states(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"

+    torch.cuda.set_device(f"cuda:{rank}")
+
    # initialize the process group
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)

    state = {"something": torch.tensor([rank])}
-    collected_state = _collect_states_on_rank_zero(state, device=torch.device(f"cuda:{rank}"))
-    if rank == 0:
-        assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}
-    else:
-        assert collected_state is None
+    collected_state = _collect_states_on_rank_zero(state)
+    assert collected_state == {1: {"something": torch.tensor([1])}, 0: {"something": torch.tensor([0])}}


@RunIf(skip_windows=True, min_gpus=2, min_torch="1.10")