From df601405d93405fb7db215a77b18394e17a6daa9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Wed, 30 Jun 2021 00:44:10 +0200
Subject: [PATCH] Use full `torch.distributed` import (#8200)

---
 pytorch_lightning/plugins/training_type/ddp.py       | 12 +++++++-----
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 12 +++++++-----
 pytorch_lightning/plugins/training_type/horovod.py   |  4 ++--
 3 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index a24bfc3497..1f99d1b09f 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union
 import __main__
 import numpy as np
 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 
@@ -307,7 +307,9 @@ class DDPPlugin(ParallelPlugin):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
+            torch.distributed.init_process_group(
+                self.torch_distributed_backend, rank=global_rank, world_size=world_size
+            )
 
         # on rank=0 let everyone know training is starting
         rank_zero_info(
@@ -333,12 +335,12 @@ class DDPPlugin(ParallelPlugin):
         self.cluster_environment.teardown()
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not torch.distributed.is_initialized():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
-            torch_distrib.barrier()
+            torch.distributed.barrier()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         return self.dist.broadcast(obj)
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index ac150323d0..6c7b24d6fe 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -17,7 +17,7 @@ import re
 from typing import Any, List, Optional, Union
 
 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
@@ -258,7 +258,9 @@ class DDPSpawnPlugin(ParallelPlugin):
 
         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
+            torch.distributed.init_process_group(
+                self.torch_distributed_backend, rank=global_rank, world_size=world_size
+            )
 
         # on rank=0 let everyone know training is starting
         rank_zero_info(
@@ -310,12 +312,12 @@ class DDPSpawnPlugin(ParallelPlugin):
         self.lightning_module.load_state_dict(ckpt)
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not torch.distributed.is_initialized():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
-            torch_distrib.barrier()
+            torch.distributed.barrier()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         return self.dist.broadcast(obj)
diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index 99899aed11..cbd9e80dab 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -15,7 +15,7 @@ from contextlib import ExitStack
 from typing import Any, List, Optional, Union
 
 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer
 
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -125,7 +125,7 @@ class HorovodPlugin(ParallelPlugin):
         self.join()
 
     def barrier(self, *args, **kwargs):
-        if torch_distrib.is_initialized():
+        if torch.distributed.is_initialized():
             self.join()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
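
For reference, below is a minimal standalone sketch of the fully qualified `torch.distributed` call pattern this patch standardizes on. It is not part of the change itself: the helper names are illustrative, the process group is assumed to be configured through the usual MASTER_ADDR/MASTER_PORT environment variables, and torch >= 1.8 is assumed for the `device_ids` argument to `barrier()`.

from typing import List, Optional

import torch
import torch.distributed


def init_process_group_if_needed(backend: str, rank: int, world_size: int) -> None:
    # Initialize the default process group only once, as the plugins above do.
    # Assumes MASTER_ADDR/MASTER_PORT are already set in the environment.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)


def barrier(device_ids: Optional[List[int]] = None) -> None:
    # No-op outside a distributed run. `device_ids` is only accepted by the
    # NCCL backend, and only on torch >= 1.8 (assumed here).
    if not torch.distributed.is_initialized():
        return
    if torch.distributed.get_backend() == "nccl":
        torch.distributed.barrier(device_ids=device_ids)
    else:
        torch.distributed.barrier()

Spelling the module path out in full, instead of the old `torch_distrib` alias, keeps these call sites consistent with the `torch.distributed.is_initialized()` checks that were already written in full.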