Fix `rank_zero_only` rank not set in ddp-spawn based strategies (#19030)

Author: Adrian Wälchli, 2023-11-20 16:49:14 +01:00 (committed by GitHub)
Commit: f652e6c00e (parent: 4789905880)
7 changed files with 24 additions and 12 deletions
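In spawn-based launching, each worker process starts with fresh module state, so the `rank` attribute carried by the `rank_zero_only` decorator is not inherited from the main process. The fix below sets it in every worker, and sets it on the original `lightning_utilities` decorator as well as on Lightning's own copy, so that code decorated with either one knows the worker's real rank. A minimal, self-contained sketch of the pattern, using a simplified stand-in rather than Lightning's actual implementation:

# Minimal sketch, assuming a simplified stand-in for the real decorator.
from functools import wraps
from typing import Any, Callable, Optional


def rank_zero_only(fn: Callable[..., Any]) -> Callable[..., Optional[Any]]:
    """Call `fn` only on the process whose global rank is stored on the decorator."""

    @wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
        rank = getattr(rank_zero_only, "rank", None)
        if rank is None:
            raise RuntimeError("`rank_zero_only.rank` must be set before use")
        return fn(*args, **kwargs) if rank == 0 else None

    return wrapped


def setup_worker(global_rank: int) -> None:
    # Spawn-launched workers re-import everything, so the attribute starts out unset
    # there; each worker assigns it once its global rank is known. The diff below does
    # this for every decorator copy in one chained assignment.
    rank_zero_only.rank = global_rank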


@@ -17,6 +17,7 @@ from typing import Any, ContextManager, Dict, List, Literal, Optional, Union
 
 import torch
 import torch.distributed
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -202,7 +203,7 @@ class DDPStrategy(ParallelStrategy):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     def _determine_ddp_device_ids(self) -> Optional[List[int]]:
         return None if self.root_device.type == "cpu" else [self.root_device.index]
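The fixed line is a chained assignment: `self.global_rank` is evaluated once and written to the `rank` attribute of both decorator objects, Lightning's re-export and the `lightning_utilities` original. A tiny illustration with hypothetical placeholder objects:

class _Decorator:
    """Hypothetical placeholder for a decorator object carrying a mutable `rank` attribute."""


rank_zero_only = _Decorator()        # stands in for Lightning's re-exported copy
utils_rank_zero_only = _Decorator()  # stands in for the lightning_utilities original

global_rank = 3
rank_zero_only.rank = utils_rank_zero_only.rank = global_rank
assert rank_zero_only.rank == utils_rank_zero_only.rank == 3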


@@ -34,6 +34,7 @@ from typing import (
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer
@@ -678,7 +679,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
 
 def _activation_checkpointing_kwargs(


@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional,
 
 import torch
 import torch.distributed
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -213,7 +214,7 @@ class DDPStrategy(ParallelStrategy):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     def _register_ddp_hooks(self) -> None:
         log.debug(f"{self.__class__.__name__}: registering ddp hooks")


@@ -19,6 +19,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
 
 import torch
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
@@ -267,7 +268,7 @@ class FSDPStrategy(ParallelStrategy):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     @override
     def _configure_launcher(self) -> None:
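All four strategies above receive the same one-line change in their rank setup. It matters for spawn-based launching because every spawned worker is a fresh interpreter and does not inherit module-level state such as a decorator's `rank` attribute from the main process. A sketch with plain `torch.multiprocessing.spawn` rather than Lightning's launcher, assuming the `lightning_utilities` decorator checks its `rank` attribute at call time (which is what the assignment in the diff relies on):

import torch.multiprocessing as mp
from lightning_utilities.core.rank_zero import rank_zero_only


@rank_zero_only
def log_once(message: str) -> None:
    print(message)


def _worker(process_idx: int, world_size: int) -> None:
    # The spawned child re-imported the modules above, so the decorator's `rank`
    # attribute is not carried over from the parent and must be set here.
    rank_zero_only.rank = process_idx
    log_once(f"printed by rank 0 only, world size {world_size}")


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)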


@@ -25,12 +25,15 @@ from lightning.fabric.utilities.distributed import _distributed_is_initialized
 @pytest.fixture(autouse=True)
 def preserve_global_rank_variable():
     """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
-    from lightning.fabric.utilities.rank_zero import rank_zero_only
+    from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric
+    from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities
 
-    rank = getattr(rank_zero_only, "rank", None)
+    functions = (rank_zero_only_fabric, rank_zero_only_utilities)
+    ranks = [getattr(fn, "rank", None) for fn in functions]
     yield
-    if rank is not None:
-        setattr(rank_zero_only, "rank", rank)
+    for fn, rank in zip(functions, ranks):
+        if rank is not None:
+            setattr(fn, "rank", rank)
 
 
 @pytest.fixture(autouse=True)
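The autouse fixture above snapshots the `rank` attribute of each decorator copy before a test and restores it afterwards, so tests that fake a non-zero rank cannot leak that state. For comparison only, a per-test alternative would be pytest's built-in `monkeypatch`, which reverts attribute changes automatically; this is a sketch of that alternative, not what the repository does:

import pytest
from lightning_utilities.core.rank_zero import rank_zero_only


def test_rank_zero_only_skips_on_other_ranks(monkeypatch: pytest.MonkeyPatch) -> None:
    # `raising=False` because the attribute may not exist yet in this process.
    monkeypatch.setattr(rank_zero_only, "rank", 1, raising=False)

    @rank_zero_only
    def only_rank_zero() -> str:
        return "ran"

    # With a simulated non-zero rank the wrapped function is skipped and returns None;
    # monkeypatch undoes the attribute change when the test finishes.
    assert only_rank_zero() is None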


@@ -41,12 +41,16 @@ def datadir():
 @pytest.fixture(autouse=True)
 def preserve_global_rank_variable():
     """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
-    from lightning.pytorch.utilities.rank_zero import rank_zero_only
+    from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric
+    from lightning.pytorch.utilities.rank_zero import rank_zero_only as rank_zero_only_pytorch
+    from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities
 
-    rank = getattr(rank_zero_only, "rank", None)
+    functions = (rank_zero_only_pytorch, rank_zero_only_fabric, rank_zero_only_utilities)
+    ranks = [getattr(fn, "rank", None) for fn in functions]
     yield
-    if rank is not None:
-        setattr(rank_zero_only, "rank", rank)
+    for fn, rank in zip(functions, ranks):
+        if rank is not None:
+            setattr(fn, "rank", rank)
 
 
 @pytest.fixture(autouse=True)
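The PyTorch-side fixture tracks one more copy (the `lightning.pytorch` re-export) but uses the same save-and-restore pattern. A small, hypothetical script showing the leakage it guards against, again assuming the `lightning_utilities` decorator semantics:

from lightning_utilities.core.rank_zero import rank_zero_only


@rank_zero_only
def record(sink: list) -> None:
    sink.append("rank-0 work")


# A test that simulates a non-zero-rank worker and forgets to clean up afterwards:
rank_zero_only.rank = 1

# A later test that implicitly assumes rank 0 now sees the leaked state and its
# rank-zero-only code is silently skipped, which is what the autouse fixture prevents.
seen: list = []
record(seen)
assert seen == []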


@@ -48,6 +48,7 @@ def test_torchscript_input_output(modelclass):
 @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
 def test_torchscript_example_input_output_trace(modelclass):
     """Test that traced LightningModule forward works with example_input_array."""
+    torch.manual_seed(1)
     model = modelclass()
     if isinstance(model, BoringModel):