From d418cf23b2438c2ec8d8f65c499828bc68653b7e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 24 Mar 2022 02:40:34 -0700 Subject: [PATCH] Do not configure launcher if processes are launched externally (#12431) --- pytorch_lightning/strategies/ddp.py | 2 +- tests/strategies/test_ddp_strategy.py | 44 ++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/strategies/ddp.py b/pytorch_lightning/strategies/ddp.py index 732a6272c7..ad67467d02 100644 --- a/pytorch_lightning/strategies/ddp.py +++ b/pytorch_lightning/strategies/ddp.py @@ -137,8 +137,8 @@ class DDPStrategy(ParallelStrategy): return self._process_group_backend def _configure_launcher(self) -> None: - self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) self._rank_0_will_call_children_scripts = True def setup_environment(self) -> None: diff --git a/tests/strategies/test_ddp_strategy.py b/tests/strategies/test_ddp_strategy.py index e1ed780275..93c9d1072a 100644 --- a/tests/strategies/test_ddp_strategy.py +++ b/tests/strategies/test_ddp_strategy.py @@ -19,7 +19,7 @@ import torch from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.plugins.environments import LightningEnvironment +from pytorch_lightning.plugins.environments import ClusterEnvironment, LightningEnvironment from pytorch_lightning.strategies import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from tests.helpers.boring_model import BoringModel @@ -147,3 +147,45 @@ def test_ddp_dont_configure_sync_batchnorm(trainer_fn): trainer.strategy.setup(trainer) # because TrainerFn is not FITTING, model is not configured with sync batchnorm assert not isinstance(trainer.strategy.model.layer, torch.nn.modules.batchnorm.SyncBatchNorm) + + +def test_configure_launcher_create_processes_externally(): + class MyClusterEnvironment(ClusterEnvironment): + @property + def creates_processes_externally(self): + return True + + @property + def main_address(self): + return "" + + @property + def main_port(self): + return 8080 + + @staticmethod + def detect(): + return True + + def world_size(self): + return 1 + + def set_world_size(self): + pass + + def global_rank(self): + return 0 + + def set_global_rank(self): + pass + + def local_rank(self): + return 0 + + def node_rank(self): + return 0 + + ddp_strategy = DDPStrategy(cluster_environment=MyClusterEnvironment()) + assert ddp_strategy.launcher is None + ddp_strategy._configure_launcher() + assert ddp_strategy.launcher is None