Remove deadlock detection / process reconciliation logic (#16204)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 172be3653f
commit da675d69bf
@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `pytorch_lightning.profiler` module ([#16359](https://github.com/Lightning-AI/lightning/pull/16359))
- Removed deadlock detection / process reconciliation (`PL_RECONCILE_PROCESS=1`) ([#16204](https://github.com/Lightning-AI/lightning/pull/16204))
- Removed the deprecated `LightningCLI` arguments ([#16380](https://github.com/Lightning-AI/lightning/pull/16380))
    * save_config_filename
    * save_config_overwrite
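For context, the behaviour this changelog entry removes was opt-in via an environment variable read at strategy setup time. Roughly, it used to be enabled like this (illustrative sketch only; after #16204 the variable has no effect):

```python
import os

# Previously, setting this before constructing the Trainer forced DDP to run
# deadlock detection / process reconciliation even when the cluster environment
# created the processes externally. The flag is now ignored.
os.environ["PL_RECONCILE_PROCESS"] = "1"
```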
@ -178,10 +178,6 @@ class BaguaStrategy(DDPStrategy):
        os.environ["LOCAL_RANK"] = str(self.local_rank)

    def setup(self, trainer: "pl.Trainer") -> None:
        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
        if self._should_run_deadlock_detection():
            self._share_information_to_prevent_deadlock()

        assert self.accelerator is not None
        self.accelerator.setup(trainer)
@ -12,13 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import signal
import tempfile
import time
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import torch
@ -52,8 +46,7 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep

if _FAIRSCALE_AVAILABLE:
@ -101,9 +94,6 @@ class DDPStrategy(ParallelStrategy):
        self._ddp_comm_wrapper = ddp_comm_wrapper
        self._model_averaging_period = model_averaging_period
        self._model_averager: Optional[ModelAverager] = None
        self._pids: List[int] = []
        self._sync_dir: Optional[str] = None
        self._rank_0_will_call_children_scripts: bool = False
        self._process_group_backend: Optional[str] = process_group_backend
        self._timeout: Optional[timedelta] = timeout
@ -145,18 +135,12 @@ class DDPStrategy(ParallelStrategy):
        assert self.cluster_environment is not None
        if not self.cluster_environment.creates_processes_externally:
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
            self._rank_0_will_call_children_scripts = True

    def setup_environment(self) -> None:
        self.setup_distributed()
        super().setup_environment()

    def setup(self, trainer: "pl.Trainer") -> None:
        # share ddp pids to all processes
        self._rank_0_will_call_children_scripts = bool(self.broadcast(self._rank_0_will_call_children_scripts))
        if self._should_run_deadlock_detection():
            self._share_information_to_prevent_deadlock()

        assert self.accelerator is not None
        self.accelerator.setup(trainer)
@ -391,73 +375,6 @@ class DDPStrategy(ParallelStrategy):
            description=f"{cls.__class__.__name__}",
        )

    def _should_run_deadlock_detection(self) -> bool:
        """Determines whether the plugin will perform process reconciliation in case of errors.

        If the environment variable `PL_RECONCILE_PROCESS` is set, run detection regardless of the cluster environment.
        By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler /
        parent process to perform the process termination, external to Lightning.
        """
        return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_will_call_children_scripts

    def _share_information_to_prevent_deadlock(self) -> None:
        self._share_pids()

        # there should be a unique sync_dir per nodes.
        if self.local_rank == 0:
            # create a temporary directory used to synchronize processes on deadlock.
            self._sync_dir = tempfile.mkdtemp()

        sync_dirs = []
        global_node_rank_zero = 0
        for _ in range(self.num_nodes):
            sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
            global_node_rank_zero += self.world_size // self.num_nodes

        self._sync_dir = sync_dirs[self.node_rank]

    def _share_pids(self) -> None:
        """Make all DDP processes aware of all processes pids."""
        self.barrier()
        pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
        pids = pids.cpu().numpy().tolist()
        self._pids = pids if isinstance(pids, list) else [pids]

    def reconciliate_processes(self, trace: str) -> None:
        if self.world_size < 2:
            return

        if not self._should_run_deadlock_detection():
            return

        sync_dir = self._sync_dir

        if not sync_dir:
            rank_zero_warn("Error handling mechanism for deadlock detection is uninitialized. Skipping check.")
            return

        # The cluster may be configured to periodically purge the `/tmp`
        # directory, in which case `sync_dir` may not exist anymore at this
        # point. Idempotently create it to ensure its existence.
        Path(sync_dir).mkdir(parents=True, exist_ok=True)

        # save a file locally.
        torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))

        # sleep for a short time
        time.sleep(3)

        # return if all processes wrote a file in the `sync_dir`.
        # todo (tchaton) Add support for non-shared file-system which will fail.
        if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
            return

        for pid in self._pids:
            if pid != os.getpid():
                os.kill(pid, signal.SIGKILL)
        shutil.rmtree(sync_dir)
        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")

    def teardown(self) -> None:
        log.detail(f"{self.__class__.__name__}: tearing down strategy")
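For readers who want the gist of what the hunk above deletes: the mechanism (1) shared every rank's PID up front, (2) agreed on a temporary sync directory, and (3) on failure had the failing rank wait briefly for its peers to check in and then SIGKILL the ones that never did. Below is a minimal, single-node sketch of that pattern, not Lightning code: it assumes an already-initialized `gloo` process group, and the helper names `share_pids`, `share_sync_dir`, and `reconcile_on_failure` are illustrative only.

```python
import os
import signal
import tempfile
import time
from pathlib import Path
from typing import List

import torch
import torch.distributed as dist


def share_pids() -> List[int]:
    """Gather every rank's PID so any rank can later terminate its peers."""
    pid = torch.tensor([os.getpid()])
    gathered = [torch.zeros_like(pid) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, pid)  # CPU tensors are fine with the gloo backend
    return [int(t.item()) for t in gathered]


def share_sync_dir() -> str:
    """Rank 0 creates a temporary directory; all ranks agree on it via broadcast."""
    payload = [tempfile.mkdtemp() if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]


def reconcile_on_failure(sync_dir: str, pids: List[int], grace_seconds: float = 3.0) -> None:
    """Run from a failing rank: record the failure, wait, then kill peers that never checked in."""
    Path(sync_dir).mkdir(parents=True, exist_ok=True)
    # every rank that reaches its error handler drops a marker file
    Path(sync_dir, f"{dist.get_rank()}.marker").touch()
    time.sleep(grace_seconds)
    if len(os.listdir(sync_dir)) == dist.get_world_size():
        return  # all ranks hit their handler, so nothing is deadlocked
    for pid in pids:
        if pid != os.getpid():
            os.kill(pid, signal.SIGKILL)
```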
@ -145,7 +145,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
        self.cpu_offload = _init_cpu_offload(cpu_offload)
        self.backward_prefetch = backward_prefetch
        self.mixed_precision = mixed_precision
        self._rank_0_will_call_children_scripts: bool = False
        if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
            raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
        activation_checkpointing = activation_checkpointing or []
@ -215,7 +214,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
        assert self.cluster_environment is not None
        if not self.cluster_environment.creates_processes_externally:
            self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
            self._rank_0_will_call_children_scripts = True

    def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
        """Wraps the model into a
@ -248,8 +246,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
    def setup(self, trainer: "pl.Trainer") -> None:
        assert self.accelerator is not None
        self.accelerator.setup(trainer)
        # share ddp pids to all processes
        self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)

        if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
            assert self.model is not None
@ -82,9 +82,6 @@ class ParallelStrategy(Strategy, ABC):
            rank=self.global_rank,
        )

    def reconciliate_processes(self, trace: str) -> None:
        """Function to re-conciliate processes on failure."""

    def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
        """Perform a all_gather on all processes."""
        return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
@ -60,11 +60,6 @@ class DDPShardedStrategy(DDPStrategy):
        return super().connect(model)

    def setup(self, trainer: "pl.Trainer") -> None:
        # share ddp pids to all processes
        self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts)
        if self._should_run_deadlock_detection():
            self._share_information_to_prevent_deadlock()

        assert self.accelerator is not None
        self.accelerator.setup(trainer)
@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Any, Callable

import pytorch_lightning as pl
from lightning_fabric.utilities.distributed import _distributed_available
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@ -54,9 +52,6 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
            logger.finalize("failed")
    except BaseException as exception:
        trainer.state.status = TrainerStatus.INTERRUPTED
        if _distributed_available() and trainer.world_size > 1:
            # try syncing remaining processes, kill otherwise
            trainer.strategy.reconciliate_processes(traceback.format_exc())
        trainer._call_callback_hooks("on_exception", exception)
        for logger in trainer.loggers:
            logger.finalize("failed")
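The hunk above is the only call site: reconciliation was triggered from the trainer's top-level exception handler. A simplified sketch of that control flow is shown below; it is an illustration of the removed behaviour, not the current Lightning code path, and `call_with_reconciliation` is a hypothetical name.

```python
import traceback


def call_with_reconciliation(trainer, trainer_fn, *args, **kwargs):
    """Sketch of the removed error path: on an unexpected exception in a
    multi-process run, ask the strategy to reconcile (i.e. kill) peer
    processes before surfacing the failure."""
    try:
        return trainer_fn(*args, **kwargs)
    except BaseException:
        if trainer.world_size > 1:
            # previously delegated to DDPStrategy.reconciliate_processes, which
            # could raise DeadlockDetectedException after killing stuck peers
            trainer.strategy.reconciliate_processes(traceback.format_exc())
        raise
```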
@ -15,10 +15,6 @@
from lightning_fabric.utilities.exceptions import MisconfigurationException  # noqa: F401


class DeadlockDetectedException(Exception):
    """Exception used when a deadlock has been detected and processes are being killed."""


class ExitGracefullyException(SystemExit):
    """Exception used when a ``signal.SIGTERM`` is sent to the process.
@ -61,7 +61,7 @@ from pytorch_lightning.strategies import (
    SingleDeviceStrategy,
)
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
from tests_pytorch.helpers.datamodules import ClassifDataModule
@ -1803,41 +1803,6 @@ def test_exception_when_lightning_module_is_not_set_on_trainer():
        trainer.predict()


class CustomException(Exception):
    pass


@RunIf(min_cuda_gpus=2, standalone=True)
def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
    """Test that DDP kills the remaining processes when only one rank is throwing an exception."""

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            if batch_idx == 1 and self.trainer.is_global_zero:
                # rank 0: raises an exception
                # rank 1: continues training but will hang on the next barrier in the training loop
                raise CustomException
            return super().training_step(batch, batch_idx)

    model = TestModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=5,
        num_sanity_val_steps=0,
        accelerator="gpu",
        devices=2,
        strategy="ddp",
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    # simulate random failure in training_step on rank 0
    with pytest.raises(DeadlockDetectedException, match="CustomException"):
        trainer.fit(model)


@RunIf(min_cuda_gpus=1)
def test_multiple_trainer_constant_memory_allocated(tmpdir):
    """This tests ensures calling the trainer several times reset the memory back to 0."""