Use full `torch.distributed` import (#8200)

Carlos Mocholí authored 2021-06-30 00:44:10 +02:00 (committed by GitHub)
parent 47c76548aa
commit df601405d9
3 changed files with 16 additions and 12 deletions
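
The change itself is mechanical: the aliased import `torch.distributed as torch_distrib` is removed and every call site spells out `torch.distributed` in full. Both spellings bind the same module object, as the standalone check below illustrates (a sketch for illustration, not part of the diff):

import torch.distributed as torch_distrib  # old spelling used by the plugins
import torch.distributed                   # new spelling this commit standardizes on

# `import torch.distributed` binds the top-level name `torch`, so the submodule is
# reachable as `torch.distributed`; the old alias pointed at the very same object.
assert torch_distrib is torch.distributed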


@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union
 import __main__
 import numpy as np
 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 
@@ -307,7 +307,9 @@ class DDPPlugin(ParallelPlugin):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
+            torch.distributed.init_process_group(
+                self.torch_distributed_backend, rank=global_rank, world_size=world_size
+            )
 
             # on rank=0 let everyone know training is starting
             rank_zero_info(
@@ -333,12 +335,12 @@ class DDPPlugin(ParallelPlugin):
         self.cluster_environment.teardown()
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not torch.distributed.is_initialized():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
-            torch_distrib.barrier()
+            torch.distributed.barrier()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         return self.dist.broadcast(obj)
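
For context on the DDPPlugin hunks above: the pattern being touched is a guarded, one-time process-group initialization followed by a backend-aware barrier. The sketch below restates that pattern outside of Lightning with the fully qualified imports; the helper name `init_and_sync`, the "gloo" backend, loopback address, port, rank and world size are all placeholder choices for a single-process demo, not Lightning's defaults.

import os

import torch
import torch.distributed


def init_and_sync(backend: str = "gloo", rank: int = 0, world_size: int = 1) -> None:
    # Placeholder rendezvous settings for a local, single-process run.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    # Initialize the default process group only once, as the plugin does.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)

    # NCCL barriers accept device_ids (torch >= 1.8) so the collective runs on the
    # right GPU; other backends fall back to the plain barrier.
    if torch.distributed.get_backend() == "nccl":
        torch.distributed.barrier(device_ids=[torch.cuda.current_device()])
    else:
        torch.distributed.barrier()


if __name__ == "__main__":
    init_and_sync()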


@@ -17,7 +17,7 @@ import re
 from typing import Any, List, Optional, Union
 
 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
@@ -258,7 +258,9 @@ class DDPSpawnPlugin(ParallelPlugin):
         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
+            torch.distributed.init_process_group(
+                self.torch_distributed_backend, rank=global_rank, world_size=world_size
+            )
 
             # on rank=0 let everyone know training is starting
             rank_zero_info(
@@ -310,12 +312,12 @@ class DDPSpawnPlugin(ParallelPlugin):
             self.lightning_module.load_state_dict(ckpt)
 
     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not torch.distributed.is_initialized():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
-            torch_distrib.barrier()
+            torch.distributed.barrier()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         return self.dist.broadcast(obj)
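
The DDPSpawnPlugin hunks above apply the identical substitution; the only extra moving part in the spawn variant is that workers are launched through `torch.multiprocessing`. Below is a hedged sketch of that launch-and-initialize pattern; the worker function, address, port and world size are illustrative, not Lightning's own values.

import os

import torch.distributed
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # Each spawned process joins the same default group under its own rank.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
    torch.distributed.barrier()  # all workers meet here before tearing down
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    # mp.spawn passes the process index as the first argument to _worker.
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)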


@@ -15,7 +15,7 @@ from contextlib import ExitStack
 from typing import Any, List, Optional, Union
 
 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer
 
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -125,7 +125,7 @@ class HorovodPlugin(ParallelPlugin):
         self.join()
 
     def barrier(self, *args, **kwargs):
-        if torch_distrib.is_initialized():
+        if torch.distributed.is_initialized():
             self.join()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
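
HorovodPlugin only needs the qualified name inside `barrier`, which still delegates to `self.join()`. As a rough illustration of the primitive that delegation relies on (assuming a recent Horovod build where `horovod.torch` exposes `init()` and `join()`), a bare-bones synchronization looks like the following; it is a sketch of the primitive, not the plugin's implementation.

import horovod.torch as hvd

hvd.init()
# hvd.join() blocks until every Horovod worker has called it, which is what lets
# the plugin reuse it as a barrier; it returns the rank of the last worker to join.
last_rank = hvd.join()
print(f"all {hvd.size()} worker(s) synchronized; last to join was rank {last_rank}")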