From 6cb5813a5e7a93f2c6971229abce8441d965305c Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Thu, 15 Feb 2024 17:35:26 +0000
Subject: [PATCH] Fix the number of nodes not defined properly (#19482)

---
 src/lightning/data/utilities/env.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/lightning/data/utilities/env.py b/src/lightning/data/utilities/env.py
index 027346d216..37fbf7bb25 100644
--- a/src/lightning/data/utilities/env.py
+++ b/src/lightning/data/utilities/env.py
@@ -31,21 +31,19 @@ class _DistributedEnv:
         if torch.distributed.is_available() and torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             global_rank = torch.distributed.get_rank()
+            # Note: On multi node CPU, the number of nodes won't be correct.
+            num_nodes = world_size // torch.cuda.device_count() if torch.cuda.is_available() else world_size
+            if torch.cuda.is_available() and world_size % torch.cuda.device_count() != 0:
+                raise RuntimeError("The world size should be divisible by the number of GPUs.")
         else:
             world_size = None
             global_rank = 0
+            num_nodes = 1
 
         if world_size is None or world_size == -1:
             world_size = 1
 
-        # TODO: Add support for other accelerators
-        num_nodes = (world_size // torch.cuda.device_count()) if torch.cuda.is_available() else 1
-
-        if num_nodes > 1:
-            # validate the world size is divisble by the number of GPUs
-            assert world_size % torch.cuda.device_count() == 0
-
-        return cls(world_size=world_size, global_rank=global_rank, num_nodes=max(1, num_nodes))
+        return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"
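
For reference, below is a minimal standalone sketch of the node-count detection introduced by this patch, assuming only that torch is installed; the function name infer_distributed_env is hypothetical and not part of the patched module.

import torch

def infer_distributed_env():
    # Hypothetical sketch mirroring the logic added by this patch.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        global_rank = torch.distributed.get_rank()
        # With CUDA, derive the node count from the per-node GPU count
        # (e.g. world_size=16 with 8 GPUs per node gives num_nodes=2).
        # On multi-node CPU, this falls back to world_size and is not exact.
        num_nodes = world_size // torch.cuda.device_count() if torch.cuda.is_available() else world_size
        if torch.cuda.is_available() and world_size % torch.cuda.device_count() != 0:
            raise RuntimeError("The world size should be divisible by the number of GPUs.")
    else:
        world_size, global_rank, num_nodes = None, 0, 1

    if world_size is None or world_size == -1:
        world_size = 1

    return world_size, global_rank, num_nodes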