Remove deadlock detection / process reconciliation logic (#16204)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Adrian Wälchli 2023-01-18 13:37:57 +01:00 committed by Luca Antiga
parent 172be3653f
commit da675d69bf
9 changed files with 5 additions and 145 deletions

View File

@@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed the deprecated `pytorch_lightning.profiler` module ([#16359](https://github.com/Lightning-AI/lightning/pull/16359))
- Removed deadlock detection / process reconciliation (`PL_RECONCILE_PROCESS=1`) ([#16204](https://github.com/Lightning-AI/lightning/pull/16204))
- Removed the deprecated `LightningCLI` arguments ([#16380](https://github.com/Lightning-AI/lightning/pull/16380))
* save_config_filename
* save_config_overwrite

View File

@@ -178,10 +178,6 @@ class BaguaStrategy(DDPStrategy):
os.environ["LOCAL_RANK"] = str(self.local_rank)
def setup(self, trainer: "pl.Trainer") -> None:
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()
assert self.accelerator is not None
self.accelerator.setup(trainer)

View File

@@ -12,13 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import signal
import tempfile
import time
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
import torch
@@ -52,8 +46,7 @@ from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
if _FAIRSCALE_AVAILABLE:
@@ -101,9 +94,6 @@ class DDPStrategy(ParallelStrategy):
self._ddp_comm_wrapper = ddp_comm_wrapper
self._model_averaging_period = model_averaging_period
self._model_averager: Optional[ModelAverager] = None
self._pids: List[int] = []
self._sync_dir: Optional[str] = None
self._rank_0_will_call_children_scripts: bool = False
self._process_group_backend: Optional[str] = process_group_backend
self._timeout: Optional[timedelta] = timeout
@@ -145,18 +135,12 @@
assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
self._rank_0_will_call_children_scripts = True
def setup_environment(self) -> None:
self.setup_distributed()
super().setup_environment()
def setup(self, trainer: "pl.Trainer") -> None:
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = bool(self.broadcast(self._rank_0_will_call_children_scripts))
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()
assert self.accelerator is not None
self.accelerator.setup(trainer)
@@ -391,73 +375,6 @@
description=f"{cls.__class__.__name__}",
)
def _should_run_deadlock_detection(self) -> bool:
"""Determines whether the plugin will perform process reconciliation in case of errors.
If the environment variable `PL_RECONCILE_PROCESS` is set, run detection regardless of the cluster environment.
By default this is disabled. Otherwise, if the cluster environment creates the processes, allow the scheduler /
parent process to perform the process termination, external to Lightning.
"""
return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or self._rank_0_will_call_children_scripts
def _share_information_to_prevent_deadlock(self) -> None:
self._share_pids()
# there should be a unique sync_dir per node.
if self.local_rank == 0:
# create a temporary directory used to synchronize processes on deadlock.
self._sync_dir = tempfile.mkdtemp()
sync_dirs = []
global_node_rank_zero = 0
for _ in range(self.num_nodes):
sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
global_node_rank_zero += self.world_size // self.num_nodes
self._sync_dir = sync_dirs[self.node_rank]
def _share_pids(self) -> None:
"""Make all DDP processes aware of all processes pids."""
self.barrier()
pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
pids = pids.cpu().numpy().tolist()
self._pids = pids if isinstance(pids, list) else [pids]
def reconciliate_processes(self, trace: str) -> None:
if self.world_size < 2:
return
if not self._should_run_deadlock_detection():
return
sync_dir = self._sync_dir
if not sync_dir:
rank_zero_warn("Error handling mechanism for deadlock detection is uninitialized. Skipping check.")
return
# The cluster may be configured to periodically purge the `/tmp`
# directory, in which case `sync_dir` may not exist anymore at this
# point. Idempotently create it to ensure its existence.
Path(sync_dir).mkdir(parents=True, exist_ok=True)
# save a file locally.
torch.save(True, os.path.join(sync_dir, f"{self.global_rank}.pl"))
# sleep for a short time
time.sleep(3)
# return if all processes wrote a file in the `sync_dir`.
# todo (tchaton) Add support for non-shared file-systems, where this check will fail.
if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
return
for pid in self._pids:
if pid != os.getpid():
os.kill(pid, signal.SIGKILL)
shutil.rmtree(sync_dir)
raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
def teardown(self) -> None:
log.detail(f"{self.__class__.__name__}: tearing down strategy")
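For readers unfamiliar with the removed feature, the deleted block above combined an opt-in gate with a file-based check that decided whether to kill sibling processes. The following is a simplified, stdlib-only sketch with hypothetical helper names, not the original implementation (which additionally broadcast a per-node sync directory and wrote the marker with torch.save):

import os
import signal
import time
from typing import List


def _should_reconcile(rank_zero_launched_children: bool) -> bool:
    # Mirror of the removed gate: opt in via PL_RECONCILE_PROCESS=1, or run
    # automatically when rank 0 itself launched the child worker scripts.
    return os.getenv("PL_RECONCILE_PROCESS", "0") == "1" or rank_zero_launched_children


def _reconcile_after_error(sync_dir: str, global_rank: int, local_world_size: int, peer_pids: List[int]) -> bool:
    """Return True if every local rank reached the error handler, else kill the stragglers."""
    os.makedirs(sync_dir, exist_ok=True)
    # Every rank that hit an exception drops a marker file into the shared directory ...
    open(os.path.join(sync_dir, f"{global_rank}.pl"), "w").close()
    time.sleep(3)  # grace period for the other local ranks to arrive
    if len(os.listdir(sync_dir)) == local_world_size:
        return True  # all local ranks failed on their own; nothing is hung on a collective
    # ... otherwise some rank is presumed stuck, so its sibling processes are terminated.
    for pid in peer_pids:
        if pid != os.getpid():
            os.kill(pid, signal.SIGKILL)
    return False

After this commit, that responsibility falls to the cluster scheduler or the launching parent process, which is what the removed docstring already recommended for externally created processes.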

View File

@@ -145,7 +145,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.backward_prefetch = backward_prefetch
self.mixed_precision = mixed_precision
self._rank_0_will_call_children_scripts: bool = False
if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")
activation_checkpointing = activation_checkpointing or []
@@ -215,7 +214,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
assert self.cluster_environment is not None
if not self.cluster_environment.creates_processes_externally:
self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes)
self._rank_0_will_call_children_scripts = True
def _setup_model(self, model: torch.nn.Module) -> FullyShardedDataParallel:
"""Wraps the model into a
@@ -248,8 +246,6 @@ class DDPFullyShardedNativeStrategy(ParallelStrategy):
def setup(self, trainer: "pl.Trainer") -> None:
assert self.accelerator is not None
self.accelerator.setup(trainer)
# share ddp pids to all processes
self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts)
if trainer.state.fn == TrainerFn.FITTING and self._layer_sync:
assert self.model is not None

View File

@@ -82,9 +82,6 @@ class ParallelStrategy(Strategy, ABC):
rank=self.global_rank,
)
def reconciliate_processes(self, trace: str) -> None:
"""Function to re-conciliate processes on failure."""
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""Perform a all_gather on all processes."""
return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
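The reconciliate_processes hook above is removed, while all_gather stays. Its main in-tree caller was the removed _share_pids earlier in this diff; the following hypothetical subclass (not part of this change) sketches the same call pattern against the surviving API:

import os

import torch

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy


class PidSharingDDPStrategy(DDPStrategy):
    # Hypothetical subclass mirroring the removed _share_pids(): every rank
    # gathers the PIDs of all processes via the surviving all_gather API.
    def setup(self, trainer: "pl.Trainer") -> None:
        super().setup(trainer)
        pids = self.all_gather(torch.tensor(os.getpid(), device=self.root_device))
        self._peer_pids = pids.cpu().numpy().tolist()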

View File

@@ -60,11 +60,6 @@ class DDPShardedStrategy(DDPStrategy):
return super().connect(model)
def setup(self, trainer: "pl.Trainer") -> None:
# share ddp pids to all processes
self._rank_0_will_call_children_scripts: bool = self.broadcast(self._rank_0_will_call_children_scripts)
if self._should_run_deadlock_detection():
self._share_information_to_prevent_deadlock()
assert self.accelerator is not None
self.accelerator.setup(trainer)

View File

@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Any, Callable
import pytorch_lightning as pl
from lightning_fabric.utilities.distributed import _distributed_available
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.utilities.exceptions import _TunerExitException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -54,9 +52,6 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
logger.finalize("failed")
except BaseException as exception:
trainer.state.status = TrainerStatus.INTERRUPTED
if _distributed_available() and trainer.world_size > 1:
# try syncing remaining processes, kill otherwise
trainer.strategy.reconciliate_processes(traceback.format_exc())
trainer._call_callback_hooks("on_exception", exception)
for logger in trainer.loggers:
logger.finalize("failed")
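With the reconciliation call above gone, an exception on one rank is simply re-raised there, and the surviving ranks wait on their next collective until the process-group timeout expires (behaviour depends on the backend). A hedged sketch of bounding that wait through the timeout argument that DDPStrategy keeps (visible in the DDPStrategy diff above):

from datetime import timedelta

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# With PL_RECONCILE_PROCESS removed, how long healthy ranks block after a peer
# failure is governed by the process group; a shorter timeout is one way to bound it.
trainer = Trainer(
    accelerator="cpu",
    devices=2,
    strategy=DDPStrategy(timeout=timedelta(minutes=5)),
)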

View File

@@ -15,10 +15,6 @@
from lightning_fabric.utilities.exceptions import MisconfigurationException # noqa: F401
class DeadlockDetectedException(Exception):
"""Exception used when a deadlock has been detected and processes are being killed."""
class ExitGracefullyException(SystemExit):
"""Exception used when a ``signal.SIGTERM`` is sent to the process.

View File

@@ -61,7 +61,7 @@ from pytorch_lightning.strategies import (
SingleDeviceStrategy,
)
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
from tests_pytorch.helpers.datamodules import ClassifDataModule
@@ -1803,41 +1803,6 @@ def test_exception_when_lightning_module_is_not_set_on_trainer():
trainer.predict()
class CustomException(Exception):
pass
@RunIf(min_cuda_gpus=2, standalone=True)
def test_ddp_terminate_when_deadlock_is_detected(tmpdir):
"""Test that DDP kills the remaining processes when only one rank is throwing an exception."""
class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
if batch_idx == 1 and self.trainer.is_global_zero:
# rank 0: raises an exception
# rank 1: continues training but will hang on the next barrier in the training loop
raise CustomException
return super().training_step(batch, batch_idx)
model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=5,
num_sanity_val_steps=0,
accelerator="gpu",
devices=2,
strategy="ddp",
enable_progress_bar=False,
enable_model_summary=False,
)
# simulate random failure in training_step on rank 0
with pytest.raises(DeadlockDetectedException, match="CustomException"):
trainer.fit(model)
@RunIf(min_cuda_gpus=1)
def test_multiple_trainer_constant_memory_allocated(tmpdir):
"""This tests ensures calling the trainer several times reset the memory back to 0."""