Fix the number of nodes not defined properly (#19482)
parent b28b673e68
commit 6cb5813a5e
@@ -31,21 +31,19 @@ class _DistributedEnv:
         if torch.distributed.is_available() and torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             global_rank = torch.distributed.get_rank()
+            # Note: On multi node CPU, the number of nodes won't be correct.
+            num_nodes = world_size // torch.cuda.device_count() if torch.cuda.is_available() else world_size
+            if torch.cuda.is_available() and world_size % torch.cuda.device_count() != 0:
+                raise RuntimeError("The world size should be divisible by the number of GPUs.")
         else:
             world_size = None
             global_rank = 0
+            num_nodes = 1

         if world_size is None or world_size == -1:
             world_size = 1

-        # TODO: Add support for other accelerators
-        num_nodes = (world_size // torch.cuda.device_count()) if torch.cuda.is_available() else 1
-
-        if num_nodes > 1:
-            # validate the world size is divisble by the number of GPUs
-            assert world_size % torch.cuda.device_count() == 0
-
-        return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
+        return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes))

     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"
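For reference, a minimal sketch (not part of the commit) of the node-count arithmetic the patched branch performs; world_size and gpus_per_machine are hypothetical stand-ins for torch.distributed.get_world_size() and torch.cuda.device_count():

# Illustrative only: mirrors the node-count arithmetic in the patched detect().
# Hypothetical values: an 8-process job on machines with 4 GPUs each.
world_size = 8          # stands in for torch.distributed.get_world_size()
gpus_per_machine = 4    # stands in for torch.cuda.device_count()

if world_size % gpus_per_machine != 0:
    raise RuntimeError("The world size should be divisible by the number of GPUs.")

num_nodes = max(1, world_size // gpus_per_machine)
print(num_nodes)  # prints 2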