diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index 4b63b69c6b..614d38bc4d 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -133,124 +133,11 @@ class DDPBackend(DDPBase):
     def train(self):
         model = self.trainer.model
         if self.mode == 'ddp':
-            results = self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
+            results = self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
             del os.environ['WORLD_SIZE']
             return results
         else:
-            self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
-
-    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
-        """
-        Entry point for ddp
-
-        Args:
-            process_idx:
-            mp_queue: multiprocessing queue
-            model:
-            is_master:
-            proc_offset:
-
-        Returns:
-
-        """
-        # offset the process id if requested
-        process_idx = process_idx + proc_offset
-
-        # show progressbar only on progress_rank 0
-        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
-            self.trainer.progress_bar_callback.disable()
-
-        # determine which process we are and world size
-        self.trainer.local_rank = process_idx
-        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
-        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
-
-        # set warning rank
-        rank_zero_only.rank = self.trainer.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        model.trainer = self.trainer
-        model.init_ddp_connection(
-            self.trainer.global_rank,
-            self.trainer.world_size,
-            self.trainer.is_slurm_managing_tasks
-        )
-
-        # call setup after the ddp process has connected
-        self.trainer.call_setup_hook(model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.trainer.is_global_zero:
-            log.info('-' * 100)
-            log.info(f'distributed_backend={self.trainer.distributed_backend}')
-            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
-            log.info('-' * 100)
-
-        # call sync_bn before .cuda(), configure_apex and configure_ddp
-        if self.trainer.sync_batchnorm:
-            model = model.configure_sync_batchnorm(model)
-
-        # MODEL
-        # copy model to each gpu
-        if self.trainer.on_gpu:
-            gpu_idx = process_idx
-
-            # when using ddp, the master process (proc 0) continues running as the main one
-            # this means that the local rank will always be 0
-            # (even if cuda visible devices has other visible gpus)
-            # this means that the master process needs to pull the 0th visible index as the device number
-            if is_master:
-                available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
-                gpu_idx = int(available_gpus[self.trainer.local_rank])
-
-            self.trainer.root_gpu = gpu_idx
-            torch.cuda.set_device(self.trainer.root_gpu)
-            model.cuda(self.trainer.root_gpu)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
-        self.trainer.optimizers = optimizers
-        self.trainer.lr_schedulers = lr_schedulers
-        self.trainer.optimizer_frequencies = optimizer_frequencies
-
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
-        # AMP - run through amp wrapper before going to distributed DP
-        if self.trainer.amp_backend == AMPType.APEX:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.trainer.optimizers = optimizers
-            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
-
-        # DDP2 uses all GPUs on the machine
-        if self.trainer.distributed_backend == 'ddp' or self.trainer.distributed_backend == 'ddp_spawn':
-            device_ids = [self.trainer.root_gpu]
-        else:  # includes ddp_cpu
-            device_ids = None
-
-        # allow user to configure ddp
-        model = model.configure_ddp(model, device_ids)
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-
-        # get original model
-        model = self.trainer.get_model()
-
-        # persist info in ddp_spawn
-        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
-
-        # clean up memory
-        torch.cuda.empty_cache()
-
-        if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
-            return results
+            self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
 
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
@@ -258,3 +145,27 @@ class DDPBackend(DDPBase):
                 "You tried to run `.fit` or `.test` multiple times in the same script."
                 " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
             )
+
+    def set_world_ranks(self, process_idx):
+        self.trainer.local_rank = process_idx
+        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
+        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
+
+    def model_to_device(self, model, process_idx, is_master):
+        gpu_idx = process_idx
+
+        # when using ddp, the master process (proc 0) continues running as the main one
+        # this means that the local rank will always be 0
+        # (even if cuda visible devices has other visible gpus)
+        # this means that the master process needs to pull the 0th visible index as the device number
+        if is_master:
+            available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
+            gpu_idx = int(available_gpus[self.trainer.local_rank])
+
+        self.trainer.root_gpu = gpu_idx
+        torch.cuda.set_device(self.trainer.root_gpu)
+        model.cuda(self.trainer.root_gpu)
+
+    def get_device_ids(self):
+        device_ids = [self.trainer.root_gpu]
+        return device_ids
diff --git a/pytorch_lightning/accelerators/ddp_base_backend.py b/pytorch_lightning/accelerators/ddp_base_backend.py
index e534da0fc5..ff51fbe786 100644
--- a/pytorch_lightning/accelerators/ddp_base_backend.py
+++ b/pytorch_lightning/accelerators/ddp_base_backend.py
@@ -180,6 +180,9 @@ class DDPBase(Accelerator):
         # clean up memory
         torch.cuda.empty_cache()
 
+        if self.trainer.global_rank == 0:
+            return results
+
     def set_world_ranks(self, process_idx):
         raise NotImplementedError('to create a ddp backend, please implement set_world_ranks')
 
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
index a986aa7960..ec67098dc7 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
@@ -11,68 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-import os
-
-import torch
-import torch.multiprocessing as mp
-
-from pytorch_lightning.utilities.distributed import find_free_network_port
-from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
-
-try:
-    from apex import amp
-except ImportError:
-    amp = None
+from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
 
 
-class DDPCPUSpawnBackend(DDPBase):
-
-    def __init__(self, trainer, nprocs):
-        super().__init__(trainer)
-        self.mp_queue = None
-        self.nprocs = nprocs
-
-    def setup(self, model):
-        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
-
-        # pass in a state q
-        smp = mp.get_context('spawn')
-        self.mp_queue = smp.SimpleQueue()
-
-        self.trainer.model = model
-
-    def train(self):
-        model = self.trainer.model
-
-        # train in children process
-        mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))
-
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        results = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-
-        # recover the weights of the processes trained in the children
-        self.__recover_child_process_weights(model, best_path, last_path)
-        return results
-
-    def __recover_child_process_weights(self, model, best_path, last_path):
-        # transfer back the best path to the trainer
-        if self.trainer.checkpoint_callback:
-            self.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also best score
-
-        # load last weights
-        if last_path is not None and not self.trainer.testing:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt)
-
-        self.trainer.model = model
-
-    def set_world_ranks(self, process_idx):
-        self.trainer.local_rank = process_idx
-        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
-        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
+class DDPCPUSpawnBackend(DDPSpawnBackend):
 
     def model_to_device(self, model, process_idx, is_master):
         pass
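For orientation, the patch above pushes backend-specific work (rank bookkeeping, device placement, DDP device ids) behind hooks that the shared ddp_train_tmp flow in DDPBase calls and that each backend overrides. The sketch below is a minimal, standalone illustration of that template-method split; the class names BaseFlow, GpuFlow, CpuFlow and the run() driver are hypothetical stand-ins, and only the hook names set_world_ranks, model_to_device, and get_device_ids mirror the diff. It is not the actual DDPBase implementation.

# Hypothetical sketch of the hook-based split, not PyTorch Lightning code.
class BaseFlow:
    """Shared flow; backend-specific steps are deferred to overridable hooks."""

    def run(self, model, process_idx, is_master=False):
        self.set_world_ranks(process_idx)                     # rank bookkeeping per backend
        self.model_to_device(model, process_idx, is_master)   # device placement per backend
        device_ids = self.get_device_ids()                    # handed to the DDP wrapper (None on CPU)
        return {"device_ids": device_ids}

    def set_world_ranks(self, process_idx):
        raise NotImplementedError

    def model_to_device(self, model, process_idx, is_master):
        raise NotImplementedError

    def get_device_ids(self):
        raise NotImplementedError


class GpuFlow(BaseFlow):
    def set_world_ranks(self, process_idx):
        self.local_rank = process_idx

    def model_to_device(self, model, process_idx, is_master):
        # the real GPU backend resolves CUDA_VISIBLE_DEVICES for the master process
        # and calls model.cuda(...); here we only record the chosen index
        self.root_device = process_idx

    def get_device_ids(self):
        return [self.root_device]


class CpuFlow(BaseFlow):
    def set_world_ranks(self, process_idx):
        self.local_rank = process_idx

    def model_to_device(self, model, process_idx, is_master):
        pass  # nothing to move on CPU

    def get_device_ids(self):
        return None


if __name__ == "__main__":
    print(GpuFlow().run(model=None, process_idx=0, is_master=True))  # {'device_ids': [0]}
    print(CpuFlow().run(model=None, process_idx=0))                  # {'device_ids': None}

With this shape, DDPCPUSpawnBackend can shrink to overriding only the hooks that differ (here model_to_device is a no-op), which is what the last file in the diff does.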