diff --git a/pytorch_lightning/accelerator_backends/__init__.py b/pytorch_lightning/accelerator_backends/__init__.py
index b56542a4f7..d56ca53d06 100644
--- a/pytorch_lightning/accelerator_backends/__init__.py
+++ b/pytorch_lightning/accelerator_backends/__init__.py
@@ -3,3 +3,5 @@ from pytorch_lightning.accelerator_backends.tpu_backend import TPUBackend
 from pytorch_lightning.accelerator_backends.dp_backend import DataParallelBackend
 from pytorch_lightning.accelerator_backends.ddp_spawn_backend import DDPSpawnBackend
 from pytorch_lightning.accelerator_backends.cpu_backend import CPUBackend
+from pytorch_lightning.accelerator_backends.ddp_backend import DDPBackend
+from pytorch_lightning.accelerator_backends.ddp2_backend import DDP2Backend
diff --git a/pytorch_lightning/accelerator_backends/ddp2_backend.py b/pytorch_lightning/accelerator_backends/ddp2_backend.py
new file mode 100644
index 0000000000..cc14c44ebe
--- /dev/null
+++ b/pytorch_lightning/accelerator_backends/ddp2_backend.py
@@ -0,0 +1,160 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import os
+import torch
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+try:
+    from hydra.utils import to_absolute_path, get_original_cwd
+    from hydra.core.hydra_config import HydraConfig
+except ImportError:
+    HYDRA_AVAILABLE = False
+else:
+    HYDRA_AVAILABLE = True
+
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
+
+class DDP2Backend(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.task_idx = None
+
+    def setup(self):
+        self._resolve_task_idx()
+
+    def _resolve_task_idx(self):
+        if self.trainer.is_slurm_managing_tasks:
+            self.task_idx = int(os.environ['SLURM_LOCALID'])
+        else:
+            # torchelastic or general non_slurm ddp2
+            try:
+                self.task_idx = int(os.environ['LOCAL_RANK'])
+            except Exception as e:
+                m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
+                raise MisconfigurationException(m)
+
+    def train(self, model):
+        self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
+
+    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
+        """
+        Entry point for ddp
+
+        Args:
+            process_idx:
+            mp_queue: multiprocessing queue
+            model:
+            is_master:
+            proc_offset:
+
+        Returns:
+
+        """
+        # offset the process id if requested
+        process_idx = process_idx + proc_offset
+
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        self.trainer.local_rank = self.trainer.node_rank
+        self.trainer.global_rank = self.trainer.node_rank
+        self.trainer.world_size = self.trainer.num_nodes
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on world_size=0 let everyone know training is starting
+        if self.trainer.is_global_zero:
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+        self.trainer.optimizers = optimizers
+        self.trainer.lr_schedulers = lr_schedulers
+        self.trainer.optimizer_frequencies = optimizer_frequencies
+
+        # MODEL
+        # copy model to each gpu
+        if self.trainer.on_gpu:
+            gpu_idx = process_idx
+
+            # when using ddp, the master process (proc 0) continues running as the main one
+            # this means that the local rank will always be 0
+            # (even if cuda visible devices has other visible gpus)
+            # this means that the master process needs to pull the 0th visible index as the device number
+            if is_master:
+                available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
+                gpu_idx = int(available_gpus[self.trainer.local_rank])
+
+            self.trainer.root_gpu = gpu_idx
+            torch.cuda.set_device(self.trainer.root_gpu)
+            model.cuda(self.trainer.root_gpu)
+
+        # set model properties before going into wrapper
+        self.trainer.copy_trainer_model_properties(model)
+
+        # AMP
+        # run through amp wrapper before going to distributed DP
+        # TODO: remove with dropping NVIDIA AMP support
+        if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
+            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
+            self.trainer.optimizers = optimizers
+            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+
+        # DDP2 uses all GPUs on the machine
+        device_ids = self.trainer.data_parallel_device_ids
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # continue training routine
+        results = self.trainer.run_pretrain_routine(model)
+
+        # get original model
+        model = self.trainer.get_model()
+
+        # persist info in ddp_spawn
+        self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+
+        # clean up memory
+        torch.cuda.empty_cache()
diff --git a/pytorch_lightning/accelerator_backends/ddp_backend.py b/pytorch_lightning/accelerator_backends/ddp_backend.py
new file mode 100644
index 0000000000..0b90a83474
--- /dev/null
+++ b/pytorch_lightning/accelerator_backends/ddp_backend.py
@@ -0,0 +1,229 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+import os
+import torch
+import subprocess
+import sys
+from time import sleep
+import numpy as np
+from os.path import abspath
+from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning import _logger as log
+from typing import Optional
+
+try:
+    from hydra.utils import to_absolute_path, get_original_cwd
+    from hydra.core.hydra_config import HydraConfig
+except ImportError:
+    HYDRA_AVAILABLE = False
+else:
+    HYDRA_AVAILABLE = True
+
+try:
+    from apex import amp
+except ImportError:
+    APEX_AVAILABLE = False
+else:
+    APEX_AVAILABLE = True
+
+
+class DDPBackend(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.task_idx = None
+
+    def slurm_setup(self):
+        self.task_idx = int(os.environ['SLURM_LOCALID'])
+
+    def torchelastic_setup(self):
+        self.task_idx = int(os.environ['LOCAL_RANK'])
+
+    def train(self, model):
+        self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
+
+    def spawn_ddp_children(self, model):
+        port = os.environ['MASTER_PORT']
+
+        master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
+        os.environ['MASTER_PORT'] = f'{port}'
+        os.environ['MASTER_ADDR'] = f'{master_address}'
+
+        # allow the user to pass the node rank
+        node_rank = '0'
+        if 'NODE_RANK' in os.environ:
+            node_rank = os.environ['NODE_RANK']
+        if 'GROUP_RANK' in os.environ:
+            node_rank = os.environ['GROUP_RANK']
+
+        os.environ['NODE_RANK'] = node_rank
+        os.environ['LOCAL_RANK'] = '0'
+
+        # when user is using hydra find the absolute path
+        path_lib = abspath if not HYDRA_AVAILABLE else to_absolute_path
+
+        # pull out the commands used to run the script and resolve the abs file path
+        command = sys.argv
+        try:
+            full_path = path_lib(command[0])
+        except Exception as e:
+            full_path = abspath(command[0])
+
+        command[0] = full_path
+        # use the same python interpreter and actually running
+        command = [sys.executable] + command
+
+        # since this script sets the visible devices we replace the gpus flag with a number
+        num_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',').__len__()
+
+        if '--gpus' in command:
+            gpu_flag_idx = command.index('--gpus')
+            command[gpu_flag_idx + 1] = f'{num_gpus}'
+
+        os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'
+
+        self.trainer.interactive_ddp_procs = []
+        for local_rank in range(1, self.trainer.num_processes):
+            env_copy = os.environ.copy()
+            env_copy['LOCAL_RANK'] = f'{local_rank}'
+
+            # start process
+            # if hydra is available and initialized, make sure to set the cwd correctly
+            cwd: Optional[str] = None
+            if HYDRA_AVAILABLE:
+                if HydraConfig.initialized():
+                    cwd = get_original_cwd()
+            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
+            self.trainer.interactive_ddp_procs.append(proc)
+
+            # starting all processes at once can cause issues
+            # with dataloaders delay between 1-10 seconds
+            delay = np.random.uniform(1, 5, 1)[0]
+            sleep(delay)
+
+        local_rank = 0
+        results = self.ddp_train(local_rank, mp_queue=None, model=model, is_master=True)
+        del os.environ['WORLD_SIZE']
+
+        return results
+
+    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
+        """
+        Entry point for ddp
+
+        Args:
+            process_idx:
+            mp_queue: multiprocessing queue
+            model:
+            is_master:
+            proc_offset:
+
+        Returns:
+
+        """
+        # offset the process id if requested
+        process_idx = process_idx + proc_offset
+
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        # determine which process we are and world size
+        self.trainer.local_rank = process_idx
+        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
+        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on world_size=0 let everyone know training is starting
+        if self.trainer.is_global_zero:
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+        self.trainer.optimizers = optimizers
+        self.trainer.lr_schedulers = lr_schedulers
+        self.trainer.optimizer_frequencies = optimizer_frequencies
+
+        # MODEL
+        # copy model to each gpu
+        if self.trainer.on_gpu:
+            gpu_idx = process_idx
+
+            # when using ddp, the master process (proc 0) continues running as the main one
+            # this means that the local rank will always be 0
+            # (even if cuda visible devices has other visible gpus)
+            # this means that the master process needs to pull the 0th visible index as the device number
+            if is_master:
+                available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
+                gpu_idx = int(available_gpus[self.trainer.local_rank])
+
+            self.trainer.root_gpu = gpu_idx
+            torch.cuda.set_device(self.trainer.root_gpu)
+            model.cuda(self.trainer.root_gpu)
+
+        # set model properties before going into wrapper
+        self.trainer.copy_trainer_model_properties(model)
+
+        # AMP
+        # run through amp wrapper before going to distributed DP
+        # TODO: remove with dropping NVIDIA AMP support
+        if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
+            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
+            self.trainer.optimizers = optimizers
+            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+
+        # DDP2 uses all GPUs on the machine
+        if self.trainer.distributed_backend == 'ddp' or self.trainer.distributed_backend == 'ddp_spawn':
+            device_ids = [self.trainer.root_gpu]
+        else:  # includes ddp_cpu
+            device_ids = None
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # continue training routine
+        results = self.trainer.run_pretrain_routine(model)
+
+        # get original model
+        model = self.trainer.get_model()
+
+        # persist info in ddp_spawn
+        self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+
+        # clean up memory
+        torch.cuda.empty_cache()
+
+        if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
+            return results
diff --git a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
index 122355856e..678b1e57c2 100644
--- a/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerator_backends/ddp_spawn_backend.py
@@ -60,7 +60,7 @@ class DDPSpawnBackend(object):
         self.trainer.model = model
         return results
 
-    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
+    def ddp_train(self, process_idx, mp_queue, model):
         """
         Entry point for ddp
 
@@ -68,15 +68,10 @@ class DDPSpawnBackend(object):
             process_idx:
             mp_queue: multiprocessing queue
            model:
-            is_master:
-            proc_offset:
 
         Returns:
 
         """
-        # offset the process id if requested
-        process_idx = process_idx + proc_offset
-
         # show progressbar only on progress_rank 0
         if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
@@ -126,11 +121,6 @@ class DDPSpawnBackend(object):
         # copy model to each gpu
         if self.trainer.on_gpu:
             gpu_idx = process_idx
-            if is_master:
-                # source of truth is cuda for gpu idx
-                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
-                gpu_idx = int(gpus[self.trainer.local_rank])
-
             self.trainer.root_gpu = gpu_idx
             torch.cuda.set_device(self.trainer.root_gpu)
             model.cuda(self.trainer.root_gpu)
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index 1386f2ab47..1f92e61e94 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -164,15 +164,6 @@ else:
     HOROVOD_AVAILABLE = True
 
 
-try:
-    from hydra.utils import to_absolute_path, get_original_cwd
-    from hydra.core.hydra_config import HydraConfig
-except ImportError:
-    HYDRA_AVAILABLE = False
-else:
-    HYDRA_AVAILABLE = True
-
-
 try:
     import torch_xla
 except ImportError:
@@ -205,6 +196,7 @@ class TrainerDDPMixin(ABC):
     node_rank: int
     tpu_cores: int
     testing: bool
+    global_rank: int
     datamodule: Optional[LightningDataModule]
 
     @property
@@ -429,175 +421,6 @@ class TrainerDDPMixin(ABC):
 
         os.environ['MASTER_PORT'] = str(default_port)
 
-    def spawn_ddp_children(self, model):
-        port = os.environ['MASTER_PORT']
-
-        master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
-        os.environ['MASTER_PORT'] = f'{port}'
-        os.environ['MASTER_ADDR'] = f'{master_address}'
-
-        # allow the user to pass the node rank
-        node_rank = '0'
-        if 'NODE_RANK' in os.environ:
-            node_rank = os.environ['NODE_RANK']
-        if 'GROUP_RANK' in os.environ:
-            node_rank = os.environ['GROUP_RANK']
-
-        os.environ['NODE_RANK'] = node_rank
-        os.environ['LOCAL_RANK'] = '0'
-
-        # when user is using hydra find the absolute path
-        path_lib = abspath if not HYDRA_AVAILABLE else to_absolute_path
-
-        # pull out the commands used to run the script and resolve the abs file path
-        command = sys.argv
-        try:
-            full_path = path_lib(command[0])
-        except Exception as e:
-            full_path = abspath(command[0])
-
-        command[0] = full_path
-        # use the same python interpreter and actually running
-        command = [sys.executable] + command
-
-        # since this script sets the visible devices we replace the gpus flag with a number
-        num_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',').__len__()
-
-        if '--gpus' in command:
-            gpu_flag_idx = command.index('--gpus')
-            command[gpu_flag_idx + 1] = f'{num_gpus}'
-
-        os.environ['WORLD_SIZE'] = f'{num_gpus * self.num_nodes}'
-
-        self.interactive_ddp_procs = []
-        for local_rank in range(1, self.num_processes):
-            env_copy = os.environ.copy()
-            env_copy['LOCAL_RANK'] = f'{local_rank}'
-
-            # start process
-            # if hydra is available and initialized, make sure to set the cwd correctly
-            cwd: Optional[str] = None
-            if HYDRA_AVAILABLE:
-                if HydraConfig.initialized():
-                    cwd = get_original_cwd()
-            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
-            self.interactive_ddp_procs.append(proc)
-
-            # starting all processes at once can cause issues
-            # with dataloaders delay between 1-10 seconds
-            delay = np.random.uniform(1, 5, 1)[0]
-            sleep(delay)
-
-        local_rank = 0
-        results = self.ddp_train(local_rank, mp_queue=None, model=model, is_master=True)
-        del os.environ['WORLD_SIZE']
-
-        return results
-
-    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
-        """
-        Entry point for ddp
-
-        Args:
-            process_idx:
-            mp_queue: multiprocessing queue
-            model:
-            is_master:
-            proc_offset:
-
-        Returns:
-
-        """
-        # offset the process id if requested
-        process_idx = process_idx + proc_offset
-
-        # show progressbar only on progress_rank 0
-        if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
-            self.progress_bar_callback.disable()
-
-        # determine which process we are and world size
-        if self.use_ddp:
-            self.local_rank = process_idx
-            self.global_rank = self.node_rank * self.num_processes + process_idx
-            self.world_size = self.num_nodes * self.num_processes
-
-        elif self.use_ddp2:
-            self.local_rank = self.node_rank
-            self.global_rank = self.node_rank
-            self.world_size = self.num_nodes
-
-        # set warning rank
-        rank_zero_only.rank = self.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        model.trainer = self
-        model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)
-
-        # call setup after the ddp process has connected
-        self.call_setup_hook(model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.is_global_zero:
-            log.info('-' * 100)
-            log.info(f'distributed_backend={self.distributed_backend}')
-            log.info(f'All DDP processes registered. Starting ddp with {self.world_size} processes')
-            log.info('-' * 100)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
-
-        # MODEL
-        # copy model to each gpu
-        if self.on_gpu:
-            gpu_idx = process_idx
-            if is_master:
-                # source of truth is cuda for gpu idx
-                gpu_idx = self.local_rank
-
-            self.root_gpu = gpu_idx
-            torch.cuda.set_device(self.root_gpu)
-            model.cuda(self.root_gpu)
-
-        # set model properties before going into wrapper
-        self.copy_trainer_model_properties(model)
-
-        # AMP
-        # run through amp wrapper before going to distributed DP
-        # TODO: remove with dropping NVIDIA AMP support
-        if self.use_amp and not NATIVE_AMP_AVALAIBLE:
-            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
-            self.optimizers = optimizers
-            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
-
-        # DDP2 uses all GPUs on the machine
-        if self.distributed_backend == 'ddp' or self.distributed_backend == 'ddp_spawn':
-            device_ids = [self.root_gpu]
-        elif self.use_ddp2:
-            device_ids = self.data_parallel_device_ids
-        else:  # includes ddp_cpu
-            device_ids = None
-
-        # allow user to configure ddp
-        model = model.configure_ddp(model, device_ids)
-
-        # continue training routine
-        results = self.run_pretrain_routine(model)
-
-        # get original model
-        model = self.get_model()
-
-        # persist info in ddp_spawn
-        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
-
-        # clean up memory
-        torch.cuda.empty_cache()
-
-        if self.global_rank == 0 and self.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
-            return results
-
     def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
         if self.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
             return
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 59a33dad7e..d7903f4156 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -52,7 +52,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
 from pytorch_lightning.accelerator_backends import (
-    GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend)
+    GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend)
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -972,48 +972,54 @@ class Trainer(
             self._run_lr_finder_internally(model)
             model.logger = self.logger  # reset logger binding
 
-        # route to appropriate start method
-        # when using multi-node or DDP within a node start each module in a separate process
+        # set testing if set in environ
+        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
+
+        # -------------------
+        # determine ddp mode
+        # -------------------
+        # SLURM ddp
+        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
+
+        # torchelastic or general non_slurm ddp
+        te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
+        use_torchelastic_ddp = self.use_ddp and te_flags_passed
+
+        use_ddp_spawn = self.use_ddp and self.distributed_backend in ['ddp_cpu', 'ddp_spawn']
+
+        # -------------------
+        # route training mode
+        # -------------------
+        # DDP2 (cluster only)
         if self.use_ddp2:
-            if self.is_slurm_managing_tasks:
-                task = int(os.environ['SLURM_LOCALID'])
+            self.accelerator_backend = DDP2Backend(self)
+            self.accelerator_backend.setup()
+            self.accelerator_backend.train(model)
 
-            # torchelastic or general non_slurm ddp2
-            elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
-                task = int(os.environ['LOCAL_RANK'])
+        elif use_slurm_ddp:
+            self.accelerator_backend = DDPBackend(self)
+            self.accelerator_backend.slurm_setup()
+            self.accelerator_backend.train(model)
 
-            self.ddp_train(process_idx=task, mp_queue=None, model=model)
+        elif use_torchelastic_ddp:
+            self.accelerator_backend = DDPBackend(self)
+            self.accelerator_backend.torchelastic_setup()
+            self.accelerator_backend.train(model)
 
-        elif self.use_ddp:
+        # regular ddp using .spawn
+        elif use_ddp_spawn:
+            self.accelerator_backend = DDPSpawnBackend(self)
+            self.accelerator_backend.setup()
+            self.accelerator_backend.train(model, nprocs=self.num_processes)
+            results = self.accelerator_backend.teardown(model)
 
-            # set testing if set in environ
-            self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
-
-            if self.is_slurm_managing_tasks:
-                task = int(os.environ['SLURM_LOCALID'])
-                self.ddp_train(process_idx=task, mp_queue=None, model=model)
-
-            # torchelastic or general non_slurm ddp
-            elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
-                task = int(os.environ['LOCAL_RANK'])
-                self.ddp_train(process_idx=task, mp_queue=None, model=model)
-
-            elif self.distributed_backend == 'ddp_cpu':
-                self.accelerator_backend = DDPSpawnBackend(self)
-                self.accelerator_backend.setup()
-                self.accelerator_backend.train(model, nprocs=self.num_processes)
-                results = self.accelerator_backend.teardown(model)
-
-            elif self.distributed_backend == 'ddp_spawn':
-                self.accelerator_backend = DDPSpawnBackend(self)
-                self.accelerator_backend.setup()
-                self.accelerator_backend.train(model, nprocs=self.num_processes)
-                results = self.accelerator_backend.teardown(model)
-
-            elif self.distributed_backend == 'ddp':
-                self.set_random_port()
-                results = self.spawn_ddp_children(model)
+        # ddp
+        elif self.distributed_backend == 'ddp':
+            self.set_random_port()
+            self.accelerator_backend = DDPBackend(self)
+            results = self.accelerator_backend.spawn_ddp_children(model)
 
+        # dp
         elif self.use_dp:
             self.accelerator_backend = DataParallelBackend(self)
             self.accelerator_backend.setup(model)
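
Note (illustrative sketch, not part of the patch): the routing this diff adds to Trainer.fit can be read as a plain dispatch over the same flags and environment variables the diff already checks. The helper name route_training_mode below is hypothetical and used only for illustration.

import os

from pytorch_lightning.accelerator_backends import (
    DDPBackend, DDP2Backend, DDPSpawnBackend)


def route_training_mode(trainer, model):
    # mirrors the branching added to Trainer.fit in this diff (sketch only)
    te_flags_passed = 'WORLD_SIZE' in os.environ and (
        'GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)

    if trainer.use_ddp2:
        # cluster only: setup() resolves SLURM_LOCALID or LOCAL_RANK
        backend = DDP2Backend(trainer)
        backend.setup()
        backend.train(model)
    elif trainer.use_ddp and trainer.is_slurm_managing_tasks:
        backend = DDPBackend(trainer)
        backend.slurm_setup()          # task_idx from SLURM_LOCALID
        backend.train(model)
    elif trainer.use_ddp and te_flags_passed:
        backend = DDPBackend(trainer)
        backend.torchelastic_setup()   # task_idx from LOCAL_RANK
        backend.train(model)
    elif trainer.use_ddp and trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']:
        backend = DDPSpawnBackend(trainer)
        backend.setup()
        backend.train(model, nprocs=trainer.num_processes)
        backend.teardown(model)
    elif trainer.distributed_backend == 'ddp':
        trainer.set_random_port()
        backend = DDPBackend(trainer)
        # re-launches the script once per additional local rank
        backend.spawn_ddp_children(model)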