Add timeout to DeepSpeedStrategy (#20474)
* allow user to pass kwargs to DeepSpeedStrategy
* Update deepspeed.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update deepspeed.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* make timeout explicit in DeepSpeedStrategy
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
parent 1c4612e564
commit 9983f3a9ea
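For orientation, a minimal usage sketch (not part of the commit; the accelerator, device count, and 90-minute value are illustrative) of passing the new timeout argument through Fabric:

from datetime import timedelta

from lightning.fabric import Fabric
from lightning.fabric.strategies import DeepSpeedStrategy

# Give collective ops up to 90 minutes before the process group gives up,
# e.g. when one rank spends a long time preparing data before the first sync.
strategy = DeepSpeedStrategy(stage=3, timeout=timedelta(minutes=90))
fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy)
fabric.launch()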
src/lightning/fabric/strategies/deepspeed.py

@@ -18,6 +18,7 @@ import os
 import platform
 from collections.abc import Mapping
 from contextlib import AbstractContextManager, ExitStack
+from datetime import timedelta
 from itertools import chain
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -29,6 +30,7 @@ from torch.optim import Optimizer
 from typing_extensions import override

 from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning.fabric.plugins.precision import Precision
 from lightning.fabric.strategies.ddp import DDPStrategy
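The default value, default_pg_timeout, mirrors torch.distributed's stock process-group timeout (30 minutes on current torch releases). A quick way to inspect it, assuming a distributed-capable torch build:

from torch.distributed.constants import default_pg_timeout

# The stock process-group timeout; DeepSpeedStrategy now uses the same
# default unless the caller passes an explicit timedelta.
print(default_pg_timeout)  # 0:30:00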
@@ -97,6 +99,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
         load_full_weights: bool = False,
         precision: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -241,6 +244,7 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             process_group_backend=process_group_backend,
         )
         self._backward_sync_control = None  # DeepSpeed handles gradient accumulation internally
+        self._timeout: Optional[timedelta] = timeout

         self.config = self._load_config(config)
         if self.config is None:
@@ -648,7 +652,9 @@ class DeepSpeedStrategy(DDPStrategy, _Sharded):
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )

     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
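Stripped of the surrounding strategy code, the hunk above forwards the stored timeout to DeepSpeed's distributed init. A roughly equivalent standalone call (backend, port, and duration are illustrative; run under a distributed launcher such as torchrun or deepspeed):

from datetime import timedelta

import deepspeed

# DeepSpeed hands the timeout to torch.distributed when it creates the
# default process group, so stalled collectives raise after this interval
# instead of hanging for the default 30 minutes.
deepspeed.init_distributed("nccl", distributed_port=29500, timeout=timedelta(hours=2))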
src/lightning/pytorch/strategies/deepspeed.py

@@ -19,6 +19,7 @@ import platform
 from collections import OrderedDict
 from collections.abc import Generator, Mapping
 from contextlib import contextmanager
+from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional, Union

@@ -30,6 +31,7 @@ from typing_extensions import override

 import lightning.pytorch as pl
 from lightning.fabric.plugins import ClusterEnvironment
+from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
 from lightning.fabric.strategies import _StrategyRegistry
 from lightning.fabric.strategies.deepspeed import (
     _DEEPSPEED_AVAILABLE,
@@ -119,6 +121,7 @@ class DeepSpeedStrategy(DDPStrategy):
         load_full_weights: bool = False,
         precision_plugin: Optional[Precision] = None,
         process_group_backend: Optional[str] = None,
+        timeout: Optional[timedelta] = default_pg_timeout,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -264,6 +267,7 @@ class DeepSpeedStrategy(DDPStrategy):
             precision_plugin=precision_plugin,
             process_group_backend=process_group_backend,
         )
+        self._timeout: Optional[timedelta] = timeout

         self.config = self._load_config(config)
         if self.config is None:
@@ -364,7 +368,9 @@ class DeepSpeedStrategy(DDPStrategy):
             f"MEMBER: {self.global_rank + 1}/{self.world_size}"
         )
         self._process_group_backend = self._get_process_group_backend()
-        deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
+        deepspeed.init_distributed(
+            self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
+        )

     def _set_node_environment_variables(self) -> None:
         assert self.cluster_environment is not None
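On the Trainer side the new argument is used the same way; a minimal sketch (not part of the commit; accelerator, device count, and duration are illustrative):

from datetime import timedelta

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DeepSpeedStrategy

# Same knob as in Fabric: bound how long ranks wait for each other during
# distributed setup and collectives before raising instead of hanging.
trainer = Trainer(
    accelerator="cuda",
    devices=4,
    strategy=DeepSpeedStrategy(stage=3, timeout=timedelta(minutes=90)),
)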