From f652e6c00e83fa79e670429d8e8fc5e8a63c8f28 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 20 Nov 2023 16:49:14 +0100
Subject: [PATCH] Fix `rank_zero_only` rank not set in ddp-spawn based
 strategies (#19030)

---
 src/lightning/fabric/strategies/ddp.py         |  3 ++-
 src/lightning/fabric/strategies/fsdp.py        |  3 ++-
 src/lightning/pytorch/strategies/ddp.py        |  3 ++-
 src/lightning/pytorch/strategies/fsdp.py       |  3 ++-
 tests/tests_fabric/conftest.py                 | 11 +++++++----
 tests/tests_pytorch/conftest.py                | 12 ++++++++----
 tests/tests_pytorch/models/test_torchscript.py |  1 +
 7 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index f5d8bf0227..484260d7b0 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -17,6 +17,7 @@ from typing import Any, ContextManager, Dict, List, Literal, Optional, Union
 
 import torch
 import torch.distributed
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -202,7 +203,7 @@ class DDPStrategy(ParallelStrategy):
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     def _determine_ddp_device_ids(self) -> Optional[List[int]]:
         return None if self.root_device.type == "cpu" else [self.root_device.index]
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index ae125cb568..5bc8a31149 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -34,6 +34,7 @@ from typing import (
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer
@@ -678,7 +679,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
 
 def _activation_checkpointing_kwargs(
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index 4689793756..278c6e40ed 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional,
 
 import torch
 import torch.distributed
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -213,7 +214,7 @@ class DDPStrategy(ParallelStrategy):
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     def _register_ddp_hooks(self) -> None:
         log.debug(f"{self.__class__.__name__}: registering ddp hooks")
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 9c7ea3e222..c0f0d925ba 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -19,6 +19,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
 
 import torch
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
@@ -267,7 +268,7 @@ class FSDPStrategy(ParallelStrategy):
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     @override
     def _configure_launcher(self) -> None:
diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py
index f0be9263cb..fd0100dddc 100644
--- a/tests/tests_fabric/conftest.py
+++ b/tests/tests_fabric/conftest.py
@@ -25,12 +25,15 @@ from lightning.fabric.utilities.distributed import _distributed_is_initialized
 @pytest.fixture(autouse=True)
 def preserve_global_rank_variable():
     """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
-    from lightning.fabric.utilities.rank_zero import rank_zero_only
+    from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric
+    from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities
 
-    rank = getattr(rank_zero_only, "rank", None)
+    functions = (rank_zero_only_fabric, rank_zero_only_utilities)
+    ranks = [getattr(fn, "rank", None) for fn in functions]
     yield
-    if rank is not None:
-        setattr(rank_zero_only, "rank", rank)
+    for fn, rank in zip(functions, ranks):
+        if rank is not None:
+            setattr(fn, "rank", rank)
 
 
 @pytest.fixture(autouse=True)
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index f4074fb653..6cd1ed6ca8 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -41,12 +41,16 @@ def datadir():
 @pytest.fixture(autouse=True)
 def preserve_global_rank_variable():
     """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
-    from lightning.pytorch.utilities.rank_zero import rank_zero_only
+    from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric
+    from lightning.pytorch.utilities.rank_zero import rank_zero_only as rank_zero_only_pytorch
+    from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities
 
-    rank = getattr(rank_zero_only, "rank", None)
+    functions = (rank_zero_only_pytorch, rank_zero_only_fabric, rank_zero_only_utilities)
+    ranks = [getattr(fn, "rank", None) for fn in functions]
     yield
-    if rank is not None:
-        setattr(rank_zero_only, "rank", rank)
+    for fn, rank in zip(functions, ranks):
+        if rank is not None:
+            setattr(fn, "rank", rank)
 
 
 @pytest.fixture(autouse=True)
diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py
index a6ba208470..806b4db998 100644
--- a/tests/tests_pytorch/models/test_torchscript.py
+++ b/tests/tests_pytorch/models/test_torchscript.py
@@ -48,6 +48,7 @@ def test_torchscript_input_output(modelclass):
 @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
 def test_torchscript_example_input_output_trace(modelclass):
     """Test that traced LightningModule forward works with example_input_array."""
+    torch.manual_seed(1)
     model = modelclass()
 
     if isinstance(model, BoringModel):