diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index 66e10a8cbf..ca0cf3fb91 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -154,7 +154,7 @@ class AcceleratorConnector: accelerator_backend = accelerators.DDPCPUSpawnBackend(self.trainer, nprocs=self.trainer.num_processes) elif self.trainer.distributed_backend == "ddp": - accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp') + accelerator_backend = accelerators.DDPBackend(self.trainer) elif self.trainer.use_dp: accelerator_backend = accelerators.DataParallelBackend(self.trainer) diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index 5c5aa5a289..bedea6f058 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -43,30 +43,21 @@ else: class DDPBackend(Accelerator): - def __init__(self, trainer, mode: str = 'ddp'): + def __init__(self, trainer): super().__init__(trainer) self.task_idx = None self._has_spawned_children = False - self.mode = mode self.dist = LightningDistributed() def setup(self, model): - if self.mode == 'ddp': - self.__ddp_script_mode_setup() - elif self.mode == 'slurm_ddp': - self.__slurm_setup() - elif self.mode == 'torchelastic_ddp': - self.__torchelastic_setup() - + # first track model self.trainer.model = model - def __slurm_setup(self): - self.task_idx = int(os.environ['SLURM_LOCALID']) + # start the other scripts + self._call_children_scripts() - def __torchelastic_setup(self): - self.task_idx = int(os.environ['LOCAL_RANK']) + def _call_children_scripts(self): - def __ddp_script_mode_setup(self): assert self.trainer.global_rank == 0 self._check_can_spawn_children() self._has_spawned_children = True @@ -137,12 +128,9 @@ class DDPBackend(Accelerator): def train(self): model = self.trainer.model - if self.mode == 'ddp': - results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True) - del os.environ['WORLD_SIZE'] - return results - else: - self.ddp_train(process_idx=self.task_idx, model=model) + results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True) + del os.environ['WORLD_SIZE'] + return results def _check_can_spawn_children(self): if self._has_spawned_children: @@ -288,5 +276,4 @@ class DDPBackend(Accelerator): # clean up memory torch.cuda.empty_cache() - if self.trainer.global_rank == 0: - return results + return results