From f75f3bc1c65a2bb3388ffd68b1ee70ed9741525c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Jan 2024 16:24:52 +0100 Subject: [PATCH] Simplify `_get_rank()` utility function (#19220) --- src/lightning/fabric/utilities/rank_zero.py | 7 +------ src/lightning/pytorch/callbacks/early_stopping.py | 9 ++------- tests/tests_pytorch/callbacks/test_early_stopping.py | 9 ++------- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/src/lightning/fabric/utilities/rank_zero.py b/src/lightning/fabric/utilities/rank_zero.py index f64e1fc791..9120e9748b 100644 --- a/src/lightning/fabric/utilities/rank_zero.py +++ b/src/lightning/fabric/utilities/rank_zero.py @@ -30,17 +30,12 @@ from lightning_utilities.core.rank_zero import ( # noqa: F401 ) from typing_extensions import ParamSpec -import lightning.fabric from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10 rank_zero_module.log = logging.getLogger(__name__) -def _get_rank( - strategy: Optional["lightning.fabric.strategies.Strategy"] = None, -) -> Optional[int]: - if strategy is not None: - return strategy.global_rank +def _get_rank() -> Optional[int]: # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, # therefore LOCAL_RANK needs to be checked first rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 76a0f28bd2..e44b17add3 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -26,7 +26,6 @@ from torch import Tensor from typing_extensions import override import lightning.pytorch as pl -from lightning.fabric.utilities.rank_zero import _get_rank from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn @@ -265,12 +264,8 @@ class EarlyStopping(Callback): return msg @staticmethod - def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None: - rank = _get_rank( - strategy=(trainer.strategy if trainer is not None else None), # type: ignore[arg-type] - ) - if trainer is not None and trainer.world_size <= 1: - rank = None + def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None: + rank = trainer.global_rank if trainer.world_size > 1 else None message = rank_prefixed_message(message, rank) if rank is None or not log_rank_zero_only or rank == 0: log.info(message) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index d3940cc8cc..8835abfb81 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -480,7 +480,6 @@ def test_early_stopping_squeezes(): es_mock.assert_called_once_with(torch.tensor(0)) -@pytest.mark.parametrize("trainer", [Trainer(), None]) @pytest.mark.parametrize( ("log_rank_zero_only", "world_size", "global_rank", "expected_log"), [ @@ -492,15 +491,11 @@ def test_early_stopping_squeezes(): (True, 2, 1, None), ], ) -def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log): +def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log): """Checks if log.info() gets called with expected message when used within EarlyStopping.""" # set the global_rank and world_size if trainer is not None # or else always expect the simple logging message - if trainer: - trainer.strategy.global_rank = global_rank - trainer.strategy.world_size = world_size - else: - expected_log = "bar" + trainer = Mock(global_rank=global_rank, world_size=world_size) with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock: EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)