Fix `rank_zero_only` rank not set in ddp-spawn based strategies (#19030)
parent 4789905880
commit f652e6c00e
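Each strategy's set_world_ranks now assigns the worker's global rank to both the rank_zero_only function imported in the strategy module and the one from lightning_utilities, so code consulting either object sees the correct rank inside ddp-spawn workers. Below is a minimal sketch, not the Lightning source, of how a rank_zero_only-style decorator stores its rank as an attribute on the function object; if no process assigns that attribute inside a spawned worker, rank-gated utilities misbehave.

# Minimal sketch (illustrative only) of a rank_zero_only-style decorator that reads
# its rank from an attribute set on the function object itself.
from functools import wraps
from typing import Any, Callable, Optional


def rank_zero_only(fn: Callable[..., Any]) -> Callable[..., Optional[Any]]:
    @wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
        rank = getattr(rank_zero_only, "rank", None)  # assigned externally, e.g. by a strategy
        if rank is None:
            raise RuntimeError("rank_zero_only.rank needs to be set before use")
        return fn(*args, **kwargs) if rank == 0 else None

    return wrapped


@rank_zero_only
def log_once(msg: str) -> None:
    print(msg)


rank_zero_only.rank = 0  # what set_world_ranks now does for both imported function objects
log_once("hello from rank 0")

The chained assignment introduced in the diff, rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank, writes the same value onto both imported names, covering the case where they refer to different function objects.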
@@ -17,6 +17,7 @@ from typing import Any, ContextManager, Dict, List, Literal, Optional, Union
 
 import torch
 import torch.distributed
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -202,7 +203,7 @@ class DDPStrategy(ParallelStrategy):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     def _determine_ddp_device_ids(self) -> Optional[List[int]]:
        return None if self.root_device.type == "cpu" else [self.root_device.index]

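For context, a hypothetical standalone illustration of the ddp-spawn situation this commit addresses: with the "spawn" start method each worker is a fresh interpreter, so the rank must be assigned onto the imported rank_zero_only object inside every worker, not only in the parent. The _worker helper, the world size of 2, and the use of torch.multiprocessing here are illustrative assumptions; this is not how Lightning launches its workers.

import torch.multiprocessing as mp
from lightning_utilities.core.rank_zero import rank_zero_only


def _worker(rank: int, world_size: int) -> None:
    rank_zero_only.rank = rank  # analogous to what set_world_ranks does per process

    @rank_zero_only
    def report() -> None:
        print(f"rank 0 of {world_size} reporting")

    report()  # runs only in the rank-0 worker


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
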
@@ -34,6 +34,7 @@ from typing import (
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module, Parameter
 from torch.optim import Optimizer
@@ -678,7 +679,7 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
 
 def _activation_checkpointing_kwargs(

@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional,
 
 import torch
 import torch.distributed
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -213,7 +214,7 @@ class DDPStrategy(ParallelStrategy):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     def _register_ddp_hooks(self) -> None:
        log.debug(f"{self.__class__.__name__}: registering ddp hooks")

@@ -19,6 +19,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
 
 import torch
+from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
@@ -267,7 +268,7 @@ class FSDPStrategy(ParallelStrategy):
             self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail
         # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter
-        rank_zero_only.rank = self.global_rank
+        rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank
 
     @override
     def _configure_launcher(self) -> None:

@@ -25,12 +25,15 @@ from lightning.fabric.utilities.distributed import _distributed_is_initialized
 @pytest.fixture(autouse=True)
 def preserve_global_rank_variable():
     """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
-    from lightning.fabric.utilities.rank_zero import rank_zero_only
+    from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric
+    from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities
 
-    rank = getattr(rank_zero_only, "rank", None)
+    functions = (rank_zero_only_fabric, rank_zero_only_utilities)
+    ranks = [getattr(fn, "rank", None) for fn in functions]
     yield
-    if rank is not None:
-        setattr(rank_zero_only, "rank", rank)
+    for fn, rank in zip(functions, ranks):
+        if rank is not None:
+            setattr(fn, "rank", rank)
 
 
 @pytest.fixture(autouse=True)

@@ -41,12 +41,16 @@ def datadir():
 @pytest.fixture(autouse=True)
 def preserve_global_rank_variable():
     """Ensures that the rank_zero_only.rank global variable gets reset in each test."""
-    from lightning.pytorch.utilities.rank_zero import rank_zero_only
+    from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric
+    from lightning.pytorch.utilities.rank_zero import rank_zero_only as rank_zero_only_pytorch
+    from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities
 
-    rank = getattr(rank_zero_only, "rank", None)
+    functions = (rank_zero_only_pytorch, rank_zero_only_fabric, rank_zero_only_utilities)
+    ranks = [getattr(fn, "rank", None) for fn in functions]
     yield
-    if rank is not None:
-        setattr(rank_zero_only, "rank", rank)
+    for fn, rank in zip(functions, ranks):
+        if rank is not None:
+            setattr(fn, "rank", rank)
 
 
 @pytest.fixture(autouse=True)

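The widened fixtures snapshot and restore the rank attribute on every rank_zero_only variant, so a test that fakes a non-zero rank cannot leak that state into later tests. A hypothetical test relying on this behavior (the name and body are illustrative, not from the repository; it assumes the autouse fixture above is active):

from lightning_utilities.core.rank_zero import rank_zero_only


def test_simulated_non_zero_rank():
    rank_zero_only.rank = 1  # pretend this process is not rank 0

    @rank_zero_only
    def only_on_rank_zero() -> str:
        return "ran"

    # with rank != 0 the wrapped function is skipped and the default (None) is returned
    assert only_on_rank_zero() is None
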
@@ -48,6 +48,7 @@ def test_torchscript_input_output(modelclass):
 @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
 def test_torchscript_example_input_output_trace(modelclass):
     """Test that traced LightningModule forward works with example_input_array."""
+    torch.manual_seed(1)
     model = modelclass()
 
     if isinstance(model, BoringModel):

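The added torch.manual_seed(1) presumably makes the model's random initialization reproducible for the traced-versus-eager comparison. A minimal illustration of what seeding guarantees (standalone, not part of the test):

import torch

# identical seeds yield identical randomly initialized parameters
torch.manual_seed(1)
first = torch.nn.Linear(4, 4).weight.detach().clone()
torch.manual_seed(1)
second = torch.nn.Linear(4, 4).weight.detach().clone()
assert torch.equal(first, second)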