From 9e61de2063724ab6ff9cde75cba1a59d10ee5208 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Mon, 2 Aug 2021 21:48:43 +0200
Subject: [PATCH] Torch Elastic DDP DeadLock bug fix (#8655)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 .azure-pipelines/gpu-tests.yml                |  2 +-
 .../plugins/training_type/ddp.py              | 23 ++++++++----
 .../environments/torch_elastic_deadlock.py    | 37 +++++++++++++++++++
 tests/special_tests.sh                        |  8 ++++
 4 files changed, 61 insertions(+), 9 deletions(-)
 create mode 100644 tests/plugins/environments/torch_elastic_deadlock.py

diff --git a/.azure-pipelines/gpu-tests.yml b/.azure-pipelines/gpu-tests.yml
index 5a7bcff3bb..f105e75041 100644
--- a/.azure-pipelines/gpu-tests.yml
+++ b/.azure-pipelines/gpu-tests.yml
@@ -51,7 +51,7 @@ jobs:
   - bash: |
       python -c "fname = 'requirements/extra.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
       pip install fairscale>=0.3.4
-      pip install "deepspeed>=0.4.0, !=0.4.4" # FIXME: bug with 0.4.4
+      pip install "deepspeed>=0.4.3, !=0.4.4" # FIXME: bug with 0.4.4
       pip install . --requirement requirements/devel.txt
       pip list
     displayName: 'Install dependencies'
diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index a44384e18e..003d567b35 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -179,9 +179,6 @@ class DDPPlugin(ParallelPlugin):
         os.environ["NODE_RANK"] = str(self.cluster_environment.node_rank())
         os.environ["LOCAL_RANK"] = str(self.cluster_environment.local_rank())
 
-        # create a temporary directory used to synchronize processes on deadlock.
-        os.environ["PL_DDP_SYNC_TMPDIR"] = self._sync_dir = tempfile.mkdtemp()
-
         # Check if the current calling command looked like `python a/b/c.py` or `python -m a.b.c`
         # See https://docs.python.org/3/reference/import.html#main-spec
         if __main__.__spec__ is None:  # pragma: no-cover
@@ -410,8 +407,18 @@ class DDPPlugin(ParallelPlugin):
     def _share_information_to_prevent_deadlock(self):
         self._share_pids()
 
-        # remove `PL_DDP_SYNC_TMPDIR` from os.environ
-        self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None)
+        # there should be a unique sync_dir per node.
+        if self.local_rank == 0:
+            # create a temporary directory used to synchronize processes on deadlock.
+            self._sync_dir = tempfile.mkdtemp()
+
+        sync_dirs = []
+        global_node_rank_zero = 0
+        for _ in range(self.num_nodes):
+            sync_dirs.append(self.broadcast(self._sync_dir, global_node_rank_zero))
+            global_node_rank_zero += self.world_size // self.num_nodes
+
+        self._sync_dir = sync_dirs[self.node_rank]
 
     def _share_pids(self):
         """
@@ -436,11 +443,11 @@ class DDPPlugin(ParallelPlugin):
 
         # return if all processes wrote a file in the `sync_dir`.
         # todo (tchaton) Add support for non-shared file-system which will fail.
-        if len(os.listdir(sync_dir)) == self.world_size:
+        if len(os.listdir(sync_dir)) == (self.world_size // self.num_nodes):
             return
 
         for pid in self._pids:
             if pid != os.getpid():
                 os.kill(pid, signal.SIGKILL)
-                shutil.rmtree(sync_dir)
-                raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
+        shutil.rmtree(sync_dir)
+        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
diff --git a/tests/plugins/environments/torch_elastic_deadlock.py b/tests/plugins/environments/torch_elastic_deadlock.py
new file mode 100644
index 0000000000..ac2348285d
--- /dev/null
+++ b/tests/plugins/environments/torch_elastic_deadlock.py
@@ -0,0 +1,37 @@
+import os
+import sys
+from contextlib import suppress
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
+from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
+from tests.helpers.boring_model import BoringModel
+
+if os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1":
+
+    class CustomException(Exception):
+        pass
+
+    class Model(BoringModel):
+        def training_step(self, batch, batch_idx):
+            if batch_idx == 1 and self.trainer.is_global_zero:
+                # rank 0: raises an exception
+                # rank 1: continues training but will hang on the next barrier in the training loop
+                raise CustomException
+            return super().training_step(batch, batch_idx)
+
+    model = Model()
+
+    trainer = Trainer(
+        default_root_dir=".", max_epochs=1, limit_train_batches=5, num_sanity_val_steps=0, gpus=2, accelerator="ddp"
+    )
+    assert isinstance(trainer.training_type_plugin, DDPPlugin)
+
+    with suppress(DeadlockDetectedException):
+        # simulate random failure in training_step on rank 0
+        trainer.fit(model)
+
+    # used to capture success from this script in the CI.
+    print("SUCCEEDED")
+
+    sys.exit(0)
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index 96d1e3ba4a..2c9397125e 100755
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -78,6 +78,14 @@ if [ $? -eq 0 ]; then
   report+="Ran\ttests/utilities/test_warnings.py\n"
 fi
 
+# TODO: enable when CI uses torch>=1.9
+# test deadlock is properly handled with TorchElastic.
+# LOGS=$(PL_RUNNING_SPECIAL_TESTS=1 python -m torch.distributed.run --nproc_per_node=2 --max_restarts 0 -m coverage run --source pytorch_lightning -a tests/plugins/environments/torch_elastic_deadlock.py | grep "SUCCEEDED")
+# if [ -z "$LOGS" ]; then
+#   exit 1
+# fi
+# report+="Ran\ttests/plugins/environments/torch_elastic_deadlock.py\n"
+
 # test that a user can manually launch individual processes
 args="--trainer.gpus 2 --trainer.accelerator ddp --trainer.fast_dev_run 1"
 MASTER_ADDR="localhost" MASTER_PORT=1234 LOCAL_RANK=1 python pl_examples/basic_examples/simple_image_classifier.py ${args} &
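
Note (not part of the patch): the sketch below is a minimal standalone illustration of the per-node bookkeeping the fix relies on. Each node's local rank 0 creates one sync_dir and broadcasts it to the other ranks on that node, every local process drops a marker file into it, and the deadlock check then compares the file count against world_size // num_nodes rather than the global world_size. The values of world_size and num_nodes and the name procs_per_node are assumptions for illustration only; no distributed processes are launched.

import os
import tempfile

# assumed job shape for illustration: 2 nodes x 2 processes per node
world_size = 4
num_nodes = 2
procs_per_node = world_size // num_nodes  # what the fixed check compares against

# stands in for the directory that local_rank == 0 creates (and broadcasts) on one node
sync_dir = tempfile.mkdtemp()

# every local rank writes a marker file once it reaches the reconciliation point
for fake_rank in range(procs_per_node):
    with open(os.path.join(sync_dir, f"{fake_rank}.pl"), "w") as f:
        f.write("ok")

# fixed condition: count marker files per node, not against the global world size
if len(os.listdir(sync_dir)) == procs_per_node:
    print("all local ranks checked in; no deadlock on this node")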