Use full `torch.distributed` import (#8200)
parent 47c76548aa
commit df601405d9
@@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union
 import __main__
 import numpy as np
 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer

@@ -307,7 +307,9 @@ class DDPPlugin(ParallelPlugin):
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
+            torch.distributed.init_process_group(
+                self.torch_distributed_backend, rank=global_rank, world_size=world_size
+            )

         # on rank=0 let everyone know training is starting
         rank_zero_info(

@@ -333,12 +335,12 @@ class DDPPlugin(ParallelPlugin):
         self.cluster_environment.teardown()

     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not torch.distributed.is_initialized():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
-            torch_distrib.barrier()
+            torch.distributed.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
         return self.dist.broadcast(obj)

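Both hunks above converge on the same pattern: guard initialization behind torch.distributed.is_initialized(), and pass device_ids to the barrier only for the NCCL backend on torch >= 1.8. A minimal standalone sketch of that pattern — the function names and the env:// rendezvous via MASTER_ADDR/MASTER_PORT are illustrative assumptions, not the plugin's code:

import torch
import torch.distributed


def init_ddp(backend: str, global_rank: int, world_size: int) -> None:
    # Skip if another component already created the process group,
    # mirroring the plugin's is_initialized() guard.
    if not torch.distributed.is_initialized():
        # Assumes MASTER_ADDR and MASTER_PORT are exported, as the
        # plugin does just before this point.
        torch.distributed.init_process_group(backend, rank=global_rank, world_size=world_size)


def barrier(device_ids=None) -> None:
    # Calling a barrier before the group exists would raise, so bail out.
    if not torch.distributed.is_initialized():
        return
    if device_ids is not None and torch.distributed.get_backend() == "nccl":
        # torch >= 1.8 lets NCCL pin the barrier to this rank's GPU.
        torch.distributed.barrier(device_ids=device_ids)
    else:
        torch.distributed.barrier()
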
@@ -17,7 +17,7 @@ import re
 from typing import Any, List, Optional, Union

 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 import torch.multiprocessing as mp
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer

@@ -258,7 +258,9 @@ class DDPSpawnPlugin(ParallelPlugin):

         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
-            torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
+            torch.distributed.init_process_group(
+                self.torch_distributed_backend, rank=global_rank, world_size=world_size
+            )

         # on rank=0 let everyone know training is starting
         rank_zero_info(

@@ -310,12 +312,12 @@ class DDPSpawnPlugin(ParallelPlugin):
         self.lightning_module.load_state_dict(ckpt)

     def barrier(self, *args, **kwargs) -> None:
-        if not torch_distrib.is_initialized():
+        if not torch.distributed.is_initialized():
             return
         if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
-            torch_distrib.barrier(device_ids=self.determine_ddp_device_ids())
+            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
-            torch_distrib.barrier()
+            torch.distributed.barrier()

     def broadcast(self, obj: object, src: int = 0) -> object:
         return self.dist.broadcast(obj)

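DDPSpawnPlugin reaches the same connection logic through torch.multiprocessing (imported as mp above): every spawned child runs the rendezvous itself. A minimal sketch of that launch flow — the gloo backend, loopback address, and _worker name are illustrative assumptions, not the plugin's code:

import os

import torch.distributed
import torch.multiprocessing as mp


def _worker(local_rank: int, world_size: int) -> None:
    # Each child process performs its own rendezvous and joins the group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group("gloo", rank=local_rank, world_size=world_size)
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    # mp.spawn passes the process index as the first argument of _worker.
    mp.spawn(_worker, args=(world_size,), nprocs=world_size)
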
@@ -15,7 +15,7 @@ from contextlib import ExitStack
 from typing import Any, List, Optional, Union

 import torch
-import torch.distributed as torch_distrib
+import torch.distributed
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer

 from pytorch_lightning.core.optimizer import LightningOptimizer

@@ -125,7 +125,7 @@ class HorovodPlugin(ParallelPlugin):
         self.join()

     def barrier(self, *args, **kwargs):
-        if torch_distrib.is_initialized():
+        if torch.distributed.is_initialized():
             self.join()

     def broadcast(self, obj: object, src: int = 0) -> object:
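HorovodPlugin only reads torch.distributed state here; the synchronization itself is self.join(), which defers to Horovod. A minimal sketch of the underlying primitive, assuming the standard horovod.torch API:

import horovod.torch as hvd

hvd.init()
# hvd.join() blocks until every worker has called it, which is why the
# plugin can use join() as its barrier once a distributed run is active.
hvd.join()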