Simplify `_get_rank()` utility function (#19220)
This commit is contained in:
parent 564be3b521
commit f75f3bc1c6
@@ -30,17 +30,12 @@ from lightning_utilities.core.rank_zero import (  # noqa: F401
 )
 from typing_extensions import ParamSpec
 
-import lightning.fabric
 from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10
 
 rank_zero_module.log = logging.getLogger(__name__)
 
 
-def _get_rank(
-    strategy: Optional["lightning.fabric.strategies.Strategy"] = None,
-) -> Optional[int]:
-    if strategy is not None:
-        return strategy.global_rank
+def _get_rank() -> Optional[int]:
     # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
     # therefore LOCAL_RANK needs to be checked first
     rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
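After this change, `_get_rank()` in `lightning.fabric.utilities.rank_zero` no longer takes a `Strategy` and resolves the rank from environment variables alone, which also lets the module drop its `import lightning.fabric` dependency. The hunk cuts off at `rank_keys`; the sketch below shows how such an environment lookup typically proceeds. The loop body and return value are assumptions, not part of the diff.

```python
import os
from typing import Optional


def _get_rank() -> Optional[int]:
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    # Assumed: None distinguishes "no rank variable set at all" from an explicit rank 0.
    return None
```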
@@ -26,7 +26,6 @@ from torch import Tensor
 from typing_extensions import override
 
 import lightning.pytorch as pl
-from lightning.fabric.utilities.rank_zero import _get_rank
 from lightning.pytorch.callbacks.callback import Callback
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
@@ -265,12 +264,8 @@ class EarlyStopping(Callback):
         return msg
 
     @staticmethod
-    def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        rank = _get_rank(
-            strategy=(trainer.strategy if trainer is not None else None),  # type: ignore[arg-type]
-        )
-        if trainer is not None and trainer.world_size <= 1:
-            rank = None
+    def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None:
+        rank = trainer.global_rank if trainer.world_size > 1 else None
         message = rank_prefixed_message(message, rank)
         if rank is None or not log_rank_zero_only or rank == 0:
             log.info(message)
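`EarlyStopping._log_info` now requires a non-optional trainer and reads the rank straight from `trainer.global_rank`, keeping it `None` in single-process runs so the message stays unprefixed. A minimal standalone sketch of that decision logic follows; the `[rank: N]` prefix is what `rank_prefixed_message` is assumed to produce, and `_log_info_sketch` is a hypothetical name, not Lightning API.

```python
import logging
from typing import Optional

log = logging.getLogger(__name__)


def _log_info_sketch(global_rank: int, world_size: int, message: str, log_rank_zero_only: bool) -> None:
    # Only multi-process runs get a rank attached; single-process runs log the plain message.
    rank: Optional[int] = global_rank if world_size > 1 else None
    if rank is not None:
        message = f"[rank: {rank}] {message}"  # assumed rank_prefixed_message format
    # Rank 0 always logs; other ranks log only when log_rank_zero_only is False.
    if rank is None or not log_rank_zero_only or rank == 0:
        log.info(message)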
@@ -480,7 +480,6 @@ def test_early_stopping_squeezes():
     es_mock.assert_called_once_with(torch.tensor(0))
 
 
-@pytest.mark.parametrize("trainer", [Trainer(), None])
 @pytest.mark.parametrize(
     ("log_rank_zero_only", "world_size", "global_rank", "expected_log"),
     [
@@ -492,15 +491,11 @@ def test_early_stopping_squeezes():
         (True, 2, 1, None),
     ],
 )
-def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
+def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log):
     """Checks if log.info() gets called with expected message when used within EarlyStopping."""
     # set the global_rank and world_size if trainer is not None
     # or else always expect the simple logging message
-    if trainer:
-        trainer.strategy.global_rank = global_rank
-        trainer.strategy.world_size = world_size
-    else:
-        expected_log = "bar"
+    trainer = Mock(global_rank=global_rank, world_size=world_size)
 
     with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
         EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)
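With the `trainer` parametrization removed, the test no longer builds a real `Trainer`: a `unittest.mock.Mock` carrying only `global_rank` and `world_size` is sufficient, because those are the only attributes `_log_info` now touches. A hedged usage sketch of that pattern outside pytest (the assertion content is kept deliberately weak since the test's own expected values are not fully shown in the hunk):

```python
from unittest import mock
from unittest.mock import Mock

from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# A Mock with just the two attributes _log_info reads stands in for a Trainer.
trainer = Mock(global_rank=0, world_size=2)

with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
    EarlyStopping._log_info(trainer, "bar", log_rank_zero_only=True)

# Rank 0 in a 2-process world passes the rank-zero check, so the (presumably
# rank-prefixed) message should have been logged exactly once.
log_mock.assert_called_once()
```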