Simplify `_get_rank()` utility function (#19220)
This commit is contained in:
parent
564be3b521
commit
f75f3bc1c6
|
@ -30,17 +30,12 @@ from lightning_utilities.core.rank_zero import ( # noqa: F401
|
|||
)
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import lightning.fabric
|
||||
from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10
|
||||
|
||||
rank_zero_module.log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_rank(
|
||||
strategy: Optional["lightning.fabric.strategies.Strategy"] = None,
|
||||
) -> Optional[int]:
|
||||
if strategy is not None:
|
||||
return strategy.global_rank
|
||||
def _get_rank() -> Optional[int]:
|
||||
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
|
||||
# therefore LOCAL_RANK needs to be checked first
|
||||
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
|
||||
|
|
|
@ -26,7 +26,6 @@ from torch import Tensor
|
|||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.rank_zero import _get_rank
|
||||
from lightning.pytorch.callbacks.callback import Callback
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
|
||||
|
@ -265,12 +264,8 @@ class EarlyStopping(Callback):
|
|||
return msg
|
||||
|
||||
@staticmethod
|
||||
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
|
||||
rank = _get_rank(
|
||||
strategy=(trainer.strategy if trainer is not None else None), # type: ignore[arg-type]
|
||||
)
|
||||
if trainer is not None and trainer.world_size <= 1:
|
||||
rank = None
|
||||
def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None:
|
||||
rank = trainer.global_rank if trainer.world_size > 1 else None
|
||||
message = rank_prefixed_message(message, rank)
|
||||
if rank is None or not log_rank_zero_only or rank == 0:
|
||||
log.info(message)
|
||||
|
|
|
@ -480,7 +480,6 @@ def test_early_stopping_squeezes():
|
|||
es_mock.assert_called_once_with(torch.tensor(0))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("trainer", [Trainer(), None])
|
||||
@pytest.mark.parametrize(
|
||||
("log_rank_zero_only", "world_size", "global_rank", "expected_log"),
|
||||
[
|
||||
|
@ -492,15 +491,11 @@ def test_early_stopping_squeezes():
|
|||
(True, 2, 1, None),
|
||||
],
|
||||
)
|
||||
def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
|
||||
def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log):
|
||||
"""Checks if log.info() gets called with expected message when used within EarlyStopping."""
|
||||
# set the global_rank and world_size if trainer is not None
|
||||
# or else always expect the simple logging message
|
||||
if trainer:
|
||||
trainer.strategy.global_rank = global_rank
|
||||
trainer.strategy.world_size = world_size
|
||||
else:
|
||||
expected_log = "bar"
|
||||
trainer = Mock(global_rank=global_rank, world_size=world_size)
|
||||
|
||||
with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
|
||||
EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)
|
||||
|
|
Loading…
Reference in New Issue