Fix the number of nodes not being defined properly (#19482)

This commit is contained in:
thomas chaton 2024-02-15 17:35:26 +00:00 committed by GitHub
parent b28b673e68
commit 6cb5813a5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 6 additions and 8 deletions

View File

@@ -31,21 +31,19 @@ class _DistributedEnv:
if torch.distributed.is_available() and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
global_rank = torch.distributed.get_rank()
# Note: On multi node CPU, the number of nodes won't be correct.
num_nodes = world_size // torch.cuda.device_count() if torch.cuda.is_available() else world_size
if torch.cuda.is_available() and world_size % torch.cuda.device_count() != 0:
raise RuntimeError("The world size should be divisible by the number of GPUs.")
else:
world_size = None
global_rank = 0
num_nodes = 1
if world_size is None or world_size == -1:
world_size = 1
# TODO: Add support for other accelerators
num_nodes = (world_size // torch.cuda.device_count()) if torch.cuda.is_available() else 1
if num_nodes > 1:
# validate the world size is divisible by the number of GPUs
assert world_size % torch.cuda.device_count() == 0
return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes))
return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
def __repr__(self) -> str:
    """Return a concise, human-readable summary of this distributed environment.

    Includes the world size and the global rank; a trailing newline is kept
    inside the parentheses to match the original formatting exactly.
    """
    cls_name = self.__class__.__name__
    return f"{cls_name}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"