diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index c3f32004f4..caa6fbea9c 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -21,9 +21,17 @@ from typing import Optional
 
 import numpy as np
 import torch
+import torch.distributed as torch_distrib
+import torch.distributed as dist
 
 from pytorch_lightning.utilities.distributed import find_free_network_port
-from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
+from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning import _logger as log
+from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.distributed.dist import LightningDistributed
+
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -34,13 +42,14 @@ else:
     HYDRA_AVAILABLE = True
 
 
-class DDPBackend(DDPBase):
+class DDPBackend(Accelerator):
 
     def __init__(self, trainer, mode: str = 'ddp'):
         super().__init__(trainer)
         self.task_idx = None
         self._has_spawned_children = False
         self.mode = mode
+        self.dist = LightningDistributed()
 
     def setup(self, model):
         if self.mode == 'ddp':
@@ -130,11 +139,11 @@ class DDPBackend(DDPBase):
     def train(self):
         model = self.trainer.model
         if self.mode == 'ddp':
-            results = self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
+            results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
             del os.environ['WORLD_SIZE']
             return results
         else:
-            self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
+            self.ddp_train(process_idx=self.task_idx, model=model)
 
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
@@ -168,3 +177,118 @@ class DDPBackend(DDPBase):
     def get_device_ids(self):
         device_ids = [self.trainer.root_gpu]
         return device_ids
+
+    def training_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model(*args)
+        else:
+            output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def barrier(self, name: str = None):
+        if torch_distrib.is_initialized():
+            torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
+
+    def broadcast(self, obj, src=0):
+        return self.dist.broadcast(obj)
+
+    def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
+        """
+        Entry point for ddp
+
+        Args:
+            process_idx: index of the current process/task
+            model: the LightningModule to set up and run
+            is_master: True when this process is the node master
+            proc_offset: offset added to ``process_idx``
+
+        Returns:
+            the output of ``train_or_test()``, only on global rank 0
+        """
+        seed = os.environ.get("PL_GLOBAL_SEED")
+        if seed is not None:
+            seed_everything(int(seed))
+
+        # offset the process id if requested
+        process_idx = process_idx + proc_offset
+
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        # determine which process we are and world size
+        self.set_world_ranks(process_idx)
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on global rank 0, let everyone know training is starting
+        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # call sync_bn before .cuda(), configure_apex and configure_ddp
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)
+
+        # move the model to the correct device
+        self.model_to_device(model, process_idx, is_master)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.setup_optimizers(model)
+
+        # set model properties before going into wrapper
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+
+        # 16-bit
+        model = self.trainer.precision_connector.connect(model)
+
+        # device ids change depending on the DDP setup
+        device_ids = self.get_device_ids()
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # set up training routine
+        self.trainer.train_loop.setup_training(model)
+
+        # train or test
+        results = self.train_or_test()
+
+        # clean up memory
+        torch.cuda.empty_cache()
+
+        if self.trainer.global_rank == 0:
+            return results
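
For reference, the early-stopping consensus added by this diff can be illustrated outside the trainer. The sketch below is a minimal, standalone approximation of the pattern used in early_stopping_should_stop, not code from the PR: every rank contributes a 0/1 stop vote, the votes are summed with all_reduce, and the run stops only when the sum equals the world size (a unanimous vote). The helper name should_stop_everywhere and the single-process "gloo" group are assumptions made purely so the snippet runs as-is.

# standalone sketch (not part of the PR) of the all-rank stop consensus
import os

import torch
import torch.distributed as dist


def should_stop_everywhere(local_stop: bool, world_size: int) -> bool:
    # each rank contributes its local vote as 0 or 1
    stop = torch.tensor(int(local_stop))
    # sum the votes across all ranks
    dist.all_reduce(stop, op=dist.ReduceOp.SUM)
    # keep the ranks in lockstep before acting on the decision
    dist.barrier()
    # stop only if every rank voted to stop
    return stop.item() == world_size


if __name__ == "__main__":
    # single-process gloo group so the example is runnable without a launcher
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(should_stop_everywhere(True, dist.get_world_size()))  # True
    dist.destroy_process_group()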