Simplify `_get_rank()` utility function (#19220)

This commit is contained in:
awaelchli 2024-01-02 16:24:52 +01:00 committed by GitHub
parent 564be3b521
commit f75f3bc1c6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 5 additions and 20 deletions

View File

@ -30,17 +30,12 @@ from lightning_utilities.core.rank_zero import ( # noqa: F401
) )
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
import lightning.fabric
from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10 from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10
rank_zero_module.log = logging.getLogger(__name__) rank_zero_module.log = logging.getLogger(__name__)
def _get_rank( def _get_rank() -> Optional[int]:
strategy: Optional["lightning.fabric.strategies.Strategy"] = None,
) -> Optional[int]:
if strategy is not None:
return strategy.global_rank
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
# therefore LOCAL_RANK needs to be checked first # therefore LOCAL_RANK needs to be checked first
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")

View File

@ -26,7 +26,6 @@ from torch import Tensor
from typing_extensions import override from typing_extensions import override
import lightning.pytorch as pl import lightning.pytorch as pl
from lightning.fabric.utilities.rank_zero import _get_rank
from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
@ -265,12 +264,8 @@ class EarlyStopping(Callback):
return msg return msg
@staticmethod @staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None: def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None:
rank = _get_rank( rank = trainer.global_rank if trainer.world_size > 1 else None
strategy=(trainer.strategy if trainer is not None else None), # type: ignore[arg-type]
)
if trainer is not None and trainer.world_size <= 1:
rank = None
message = rank_prefixed_message(message, rank) message = rank_prefixed_message(message, rank)
if rank is None or not log_rank_zero_only or rank == 0: if rank is None or not log_rank_zero_only or rank == 0:
log.info(message) log.info(message)

View File

@ -480,7 +480,6 @@ def test_early_stopping_squeezes():
es_mock.assert_called_once_with(torch.tensor(0)) es_mock.assert_called_once_with(torch.tensor(0))
@pytest.mark.parametrize("trainer", [Trainer(), None])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("log_rank_zero_only", "world_size", "global_rank", "expected_log"), ("log_rank_zero_only", "world_size", "global_rank", "expected_log"),
[ [
@ -492,15 +491,11 @@ def test_early_stopping_squeezes():
(True, 2, 1, None), (True, 2, 1, None),
], ],
) )
def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log): def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log):
"""Checks if log.info() gets called with expected message when used within EarlyStopping.""" """Checks if log.info() gets called with expected message when used within EarlyStopping."""
# set the global_rank and world_size if trainer is not None # set the global_rank and world_size if trainer is not None
# or else always expect the simple logging message # or else always expect the simple logging message
if trainer: trainer = Mock(global_rank=global_rank, world_size=world_size)
trainer.strategy.global_rank = global_rank
trainer.strategy.world_size = world_size
else:
expected_log = "bar"
with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock: with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
EarlyStopping._log_info(trainer, "bar", log_rank_zero_only) EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)