From 0838c6bfce0b76cd4c660807333814cafab8cd14 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 3 Oct 2020 09:02:35 -0400
Subject: [PATCH] ref: decoupled ddp2 (#3816)

---
 .../accelerators/ddp2_backend.py      | 124 +++++++++--
 .../accelerators/ddp_base_backend.py  | 193 ------------------
 2 files changed, 108 insertions(+), 209 deletions(-)
 delete mode 100644 pytorch_lightning/accelerators/ddp_base_backend.py

diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py
index 28af6558a0..90ed0ff54c 100644
--- a/pytorch_lightning/accelerators/ddp2_backend.py
+++ b/pytorch_lightning/accelerators/ddp2_backend.py
@@ -15,10 +15,14 @@
 import os
 
 import torch
-
+import torch.distributed as torch_distrib
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
+from pytorch_lightning.distributed.dist import LightningDistributed
+from pytorch_lightning import _logger as log
+from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities.distributed import rank_zero_only
 
 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -29,15 +33,15 @@ else:
     HYDRA_AVAILABLE = True
 
 
-class DDP2Backend(DDPBase):
+class DDP2Backend(Accelerator):
 
     def __init__(self, trainer):
         super().__init__(trainer)
         self.task_idx = None
+        self.dist = LightningDistributed()
 
     def setup(self, model):
         self._resolve_task_idx()
-
         self.trainer.model = model
 
     def _resolve_task_idx(self):
@@ -53,7 +57,27 @@ class DDP2Backend(DDPBase):
 
     def train(self):
         model = self.trainer.model
-        self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
+        return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
+
+    def training_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model(*args)
+        else:
+            output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def barrier(self, name: str = None):
+        if torch_distrib.is_initialized():
+            torch_distrib.barrier()
 
     def training_step_end(self, output):
         if isinstance(output, Result):
@@ -75,21 +99,89 @@ class DDP2Backend(DDPBase):
         self.trainer.global_rank = self.trainer.node_rank
         self.trainer.world_size = self.trainer.num_nodes
 
-    def model_to_device(self, model, process_idx, is_master):
-        gpu_idx = process_idx
+    def broadcast(self, obj, src=0):
+        return self.dist.broadcast(obj)
 
-        # when using ddp, the master process (proc 0) continues running as the main one
-        # this means that the local rank will always be 0
-        # (even if cuda visible devices has other visible gpus)
-        # this means that the master process needs to pull the 0th visible index as the device number
-        if is_master:
-            available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
-            gpu_idx = int(available_gpus[self.trainer.local_rank])
-
-        self.trainer.root_gpu = gpu_idx
+    def model_to_device(self, model, process_idx):
+        self.trainer.root_gpu = process_idx
         torch.cuda.set_device(self.trainer.root_gpu)
         model.cuda(self.trainer.root_gpu)
 
     def get_device_ids(self):
         device_ids = self.trainer.data_parallel_device_ids
         return device_ids
+
+    def ddp_train(self, process_idx, mp_queue, model):
+        """
+        Entry point for ddp
+
+        Args:
+            process_idx:
+            mp_queue: multiprocessing queue
+            model:
+
+        Returns:
+
+        """
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        # determine which process we are and world size
+        self.set_world_ranks(process_idx)
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on world_size=0 let everyone know training is starting
+        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # call sync_bn before .cuda(), configure_apex and configure_ddp
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)
+
+        # move the model to the correct device
+        self.model_to_device(model, process_idx)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.setup_optimizers(model)
+
+        # set model properties before going into wrapper
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+
+        # 16-bit
+        model = self.trainer.precision_connector.connect(model)
+
+        # device ids change depending on the DDP setup
+        device_ids = self.get_device_ids()
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # set up training routine
+        self.trainer.train_loop.setup_training(model)
+
+        # train or test
+        results = self.train_or_test()
+
+        # clean up memory
+        torch.cuda.empty_cache()
+        return results
diff --git a/pytorch_lightning/accelerators/ddp_base_backend.py b/pytorch_lightning/accelerators/ddp_base_backend.py
deleted file mode 100644
index b27908c74e..0000000000
--- a/pytorch_lightning/accelerators/ddp_base_backend.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License
-import os
-import re
-
-import torch
-import torch.distributed as torch_distrib
-import torch.distributed as dist
-
-from pytorch_lightning import _logger as log
-from pytorch_lightning.accelerators.base_backend import Accelerator
-from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.cloud_io import atomic_save
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
-from pytorch_lightning.utilities.seed import seed_everything
-from pytorch_lightning.distributed.dist import LightningDistributed
-
-try:
-    from hydra.core.hydra_config import HydraConfig
-    from hydra.utils import get_original_cwd, to_absolute_path
-except ImportError:
-    HYDRA_AVAILABLE = False
-else:
-    HYDRA_AVAILABLE = True
-
-
-class DDPBase(Accelerator):
-
-    def __init__(self, trainer):
-        super().__init__(trainer)
-        self.dist = LightningDistributed()
-
-    def training_step(self, args):
-        if self.trainer.amp_backend == AMPType.NATIVE:
-            with torch.cuda.amp.autocast():
-                output = self.trainer.model(*args)
-        else:
-            output = self.trainer.model(*args)
-        return output
-
-    def validation_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def test_step(self, args):
-        output = self.training_step(args)
-        return output
-
-    def barrier(self, name: str = None):
-        if torch_distrib.is_initialized():
-            torch_distrib.barrier()
-
-    def early_stopping_should_stop(self, pl_module):
-        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
-        should_stop = stop == self.trainer.world_size
-        return should_stop
-
-    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
-        if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"):
-            return
-
-        # track the best model path
-        best_model_path = None
-        if self.trainer.checkpoint_callback is not None:
-            best_model_path = self.trainer.checkpoint_callback.best_model_path
-
-        if self.trainer.global_rank == 0 and mp_queue is not None:
-            rank_zero_warn('cleaning up ddp environment...')
-            # todo, pass complete checkpoint as state dictionary
-            mp_queue.put(best_model_path)
-            mp_queue.put(results)
-
-            # save the last weights
-            last_path = None
-            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
-                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-                atomic_save(model.state_dict(), last_path)
-            mp_queue.put(last_path)
-
-    def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
-        """
-        Entry point for ddp
-
-        Args:
-            process_idx:
-            mp_queue: multiprocessing queue
-            model:
-
-        Returns:
-
-        """
-        seed = os.environ.get("PL_GLOBAL_SEED")
-        if seed is not None:
-            seed_everything(int(seed))
-
-        # offset the process id if requested
-        process_idx = process_idx + proc_offset
-
-        # show progressbar only on progress_rank 0
-        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
-            self.trainer.progress_bar_callback.disable()
-
-        # determine which process we are and world size
-        self.set_world_ranks(process_idx)
-
-        # set warning rank
-        rank_zero_only.rank = self.trainer.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        model.trainer = self.trainer
-        model.init_ddp_connection(
-            self.trainer.global_rank,
-            self.trainer.world_size,
-            self.trainer.is_slurm_managing_tasks
-        )
-
-        # call setup after the ddp process has connected
-        self.trainer.call_setup_hook(model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.trainer.is_global_zero and not torch.distributed.is_initialized():
-            log.info('-' * 100)
-            log.info(f'distributed_backend={self.trainer.distributed_backend}')
-            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
-            log.info('-' * 100)
-
-        # call sync_bn before .cuda(), configure_apex and configure_ddp
-        if self.trainer.sync_batchnorm:
-            model = model.configure_sync_batchnorm(model)
-
-        # move the model to the correct device
-        self.model_to_device(model, process_idx, is_master)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        self.setup_optimizers(model)
-
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
-        # 16-bit
-        model = self.trainer.precision_connector.connect(model)
-
-        # device ids change depending on the DDP setup
-        device_ids = self.get_device_ids()
-
-        # allow user to configure ddp
-        model = model.configure_ddp(model, device_ids)
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-
-        # get original model
-        model = self.trainer.get_model()
-
-        # persist info in ddp_spawn
-        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
-
-        # clean up memory
-        torch.cuda.empty_cache()
-
-        if self.trainer.global_rank == 0:
-            return results
-
-    def broadcast(self, obj, src=0):
-        return self.dist.broadcast(obj)
-
-    def set_world_ranks(self, process_idx):
-        raise NotImplementedError('to create a ddp backend, please implement set_world_ranks')
-
-    def model_to_device(self, model, process_idx, is_master):
-        raise NotImplementedError('to create a ddp backend, please implement model_to_device')
-
-    def get_device_ids(self):
-        raise NotImplementedError('to create a ddp backend, please implement get_device_ids')
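
Editor's note (not part of the patch): after this refactor, ddp2 is still selected through the Trainer as before; DDP2Backend now simply owns its own training_step/barrier/broadcast/ddp_train logic instead of inheriting it from the deleted DDPBase. The snippet below is a minimal usage sketch against the 1.0-era API. BoringModel, RandomDataset, and the hyperparameters are invented for illustration and do not come from this PR; running it assumes at least 2 GPUs per node and a multi-node launcher (e.g. SLURM) when num_nodes > 1.

# Minimal sketch (assumed names, not from this PR) that exercises the ddp2 backend.
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """64 random feature vectors of size 32, enough to drive one short epoch."""

    def __init__(self, size=32, length=64):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class BoringModel(pl.LightningModule):
    """Tiny LightningModule; the name is illustrative only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # under ddp2 this runs DP-style, once per GPU on that node's sub-batch
        return self(batch).sum()

    def training_step_end(self, losses):
        # the per-GPU losses are gathered on the node; reduce them to a scalar
        return losses.mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == '__main__':
    trainer = pl.Trainer(
        max_epochs=1,
        gpus=2,                       # GPUs used on each node
        num_nodes=2,                  # one ddp2 process per node
        distributed_backend='ddp2',   # selects DDP2Backend internally
    )
    trainer.fit(BoringModel(), DataLoader(RandomDataset(), batch_size=8))

Because set_world_ranks() in this backend maps global_rank to node_rank and world_size to num_nodes, the barrier()/broadcast() helpers introduced in the patch synchronize across nodes, while each node's GPUs receive DP-style sub-batches that are reduced in training_step_end.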