From 291267c3bff8054ec438960857c9f2fec1d54899 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 30 Aug 2022 11:51:30 +0200 Subject: [PATCH] Unify rank zero messaging utilities (#14116) --- .../callbacks/early_stopping.py | 19 +++--- src/pytorch_lightning/utilities/rank_zero.py | 19 ++++-- src/pytorch_lightning/utilities/seed.py | 4 +- .../callbacks/test_early_stopping.py | 5 +- .../tests_pytorch/deprecated_api/__init__.py | 8 --- .../tests_pytorch/utilities/test_rank_zero.py | 62 +++++++++---------- tests/tests_pytorch/utilities/test_seed.py | 18 ------ 7 files changed, 56 insertions(+), 79 deletions(-) diff --git a/src/pytorch_lightning/callbacks/early_stopping.py b/src/pytorch_lightning/callbacks/early_stopping.py index 72d8445d84..87585bb812 100644 --- a/src/pytorch_lightning/callbacks/early_stopping.py +++ b/src/pytorch_lightning/callbacks/early_stopping.py @@ -28,7 +28,7 @@ from torch import Tensor import pytorch_lightning as pl from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message, rank_zero_warn log = logging.getLogger(__name__) @@ -259,14 +259,9 @@ class EarlyStopping(Callback): @staticmethod def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None: - if trainer: - # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled - if log_rank_zero_only and trainer.global_rank != 0: - return - # if world size is more than one then specify the rank of the process being logged - if trainer.world_size > 1: - log.info(f"[rank: {trainer.global_rank}] {message}") - return - - # if above conditions don't meet and we have to log - log.info(message) + rank = _get_rank(trainer) + if trainer is not None and trainer.world_size <= 1: + rank = None + message = _rank_prefixed_message(message, rank) + if rank is None or not log_rank_zero_only or rank == 0: + log.info(message) diff --git a/src/pytorch_lightning/utilities/rank_zero.py b/src/pytorch_lightning/utilities/rank_zero.py index e2292789c3..55bdc08930 100644 --- a/src/pytorch_lightning/utilities/rank_zero.py +++ b/src/pytorch_lightning/utilities/rank_zero.py @@ -20,6 +20,8 @@ from functools import partial, wraps from platform import python_version from typing import Any, Callable, Optional, Union +import pytorch_lightning as pl + log = logging.getLogger(__name__) @@ -35,8 +37,9 @@ def rank_zero_only(fn: Callable) -> Callable: return wrapped_fn -# TODO: this should be part of the cluster environment -def _get_rank() -> int: +def _get_rank(trainer: Optional["pl.Trainer"] = None) -> Optional[int]: + if trainer is not None: + return trainer.global_rank # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, # therefore LOCAL_RANK needs to be checked first rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") @@ -44,11 +47,12 @@ def _get_rank() -> int: rank = os.environ.get(key) if rank is not None: return int(rank) - return 0 + # None to differentiate whether an environment variable was set at all + return None # add the attribute to the function but don't overwrite in case Trainer has already set it -rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank()) +rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0) def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None: @@ -97,3 +101,10 @@ class LightningDeprecationWarning(DeprecationWarning): rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning) + + +def _rank_prefixed_message(message: str, rank: Optional[int]) -> str: + if rank is not None: + # specify the rank of the process being logged + return f"[rank: {rank}] {message}" + return message diff --git a/src/pytorch_lightning/utilities/seed.py b/src/pytorch_lightning/utilities/seed.py index 925337c784..cc9ff6673e 100644 --- a/src/pytorch_lightning/utilities/seed.py +++ b/src/pytorch_lightning/utilities/seed.py @@ -24,7 +24,7 @@ from typing import Any, Dict, Generator, Optional import numpy as np import torch -from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message, rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") seed = _select_seed_randomly(min_seed_value, max_seed_value) - log.info(f"[rank: {_get_rank()}] Global seed set to {seed}") + log.info(_rank_prefixed_message(f"Global seed set to {seed}", _get_rank())) os.environ["PL_GLOBAL_SEED"] = str(seed) random.seed(seed) np.random.seed(seed) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 458df2ea23..a3a98027cc 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -472,9 +472,8 @@ def test_early_stopping_squeezes(): (True, 2, 1, None), ], ) -def test_early_stopping_log_info(tmpdir, trainer, log_rank_zero_only, world_size, global_rank, expected_log): - """checks if log.info() gets called with expected message when used within EarlyStopping.""" - +def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log): + """Checks if log.info() gets called with expected message when used within EarlyStopping.""" # set the global_rank and world_size if trainer is not None # or else always expect the simple logging message if trainer: diff --git a/tests/tests_pytorch/deprecated_api/__init__.py b/tests/tests_pytorch/deprecated_api/__init__.py index 611637d543..6e29ec8b3a 100644 --- a/tests/tests_pytorch/deprecated_api/__init__.py +++ b/tests/tests_pytorch/deprecated_api/__init__.py @@ -11,20 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test deprecated functionality which will be removed in vX.Y.Z.""" -import sys from contextlib import contextmanager from typing import Optional from tests_pytorch.helpers.utils import no_warning_call -def _soft_unimport_module(str_module): - # once the module is imported e.g with parsing with pytest it lives in memory - if str_module in sys.modules: - del sys.modules[str_module] - - @contextmanager def no_deprecated_call(match: Optional[str] = None): with no_warning_call(expected_warning=DeprecationWarning, match=match): diff --git a/tests/tests_pytorch/utilities/test_rank_zero.py b/tests/tests_pytorch/utilities/test_rank_zero.py index ebc827cc46..76fa27926a 100644 --- a/tests/tests_pytorch/utilities/test_rank_zero.py +++ b/tests/tests_pytorch/utilities/test_rank_zero.py @@ -12,49 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from typing import Mapping +import sys from unittest import mock import pytest - -@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}, {"JSM_NAMESPACE_RANK": "0"}]) -def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]): - """Test that SLURM environment variables are properly checked for rank_zero_only.""" - from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only - - rank_zero_only.rank = _get_rank() - - with mock.patch.dict(os.environ, env_vars): - from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only - - rank_zero_only.rank = _get_rank() - - @rank_zero_only - def foo(): # The return type is optional because on non-zero ranks it will not be called - return 1 - - x = foo() - assert x == 1 +from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message @pytest.mark.parametrize( - "rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3"), ("JSM_NAMESPACE_RANK", "4")] + "env_vars, expected", + [ + ({"RANK": "0"}, 1), + ({"SLURM_PROCID": "0"}, 1), + ({"LOCAL_RANK": "0"}, 1), + ({"JSM_NAMESPACE_RANK": "0"}, 1), + ({}, 1), + ({"RANK": "1"}, None), + ({"SLURM_PROCID": "2"}, None), + ({"LOCAL_RANK": "3"}, None), + ({"JSM_NAMESPACE_RANK": "4"}, None), + ], ) -def test_rank_zero_none_set(rank_key, rank): - """Test that function is not called when rank environment variables are not global zero.""" - - with mock.patch.dict(os.environ, {rank_key: rank}): - from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only - - rank_zero_only.rank = _get_rank() +def test_rank_zero_known_environment_variables(env_vars, expected): + """Test that rank environment variables are properly checked for rank_zero_only.""" + with mock.patch.dict(os.environ, env_vars): + # force module reload to re-trigger the rank_zero_only.rank global computation + sys.modules.pop("pytorch_lightning.utilities.rank_zero", None) + from pytorch_lightning.utilities.rank_zero import rank_zero_only @rank_zero_only def foo(): return 1 - x = foo() - assert x is None + assert foo() == expected @pytest.mark.parametrize( @@ -69,6 +60,13 @@ def test_rank_zero_none_set(rank_key, rank): def test_rank_zero_priority(environ, expected_rank): """Test the priority in which the rank gets determined when multiple environment variables are available.""" with mock.patch.dict(os.environ, environ): - from pytorch_lightning.utilities.rank_zero import _get_rank - assert _get_rank() == expected_rank + + +@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}]) +def test_rank_prefixed_message_with_env_vars(env_vars): + with mock.patch.dict(os.environ, env_vars, clear=True): + rank = _get_rank() + message = _rank_prefixed_message("bar", rank) + + assert message == f"[rank: {rank}] bar" diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py index c8df824e93..2c89883e3c 100644 --- a/tests/tests_pytorch/utilities/test_seed.py +++ b/tests/tests_pytorch/utilities/test_seed.py @@ -1,8 +1,6 @@ import os import random -from typing import Mapping from unittest import mock -from unittest.mock import MagicMock import numpy as np import pytest @@ -116,19 +114,3 @@ def test_backward_compatibility_rng_states_dict(): assert "torch.cuda" in states states.pop("torch.cuda") _set_rng_states(states) - - -@mock.patch("pytorch_lightning.utilities.seed.log.info") -@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}]) -def test_seed_everything_log_info(log_mock: MagicMock, env_vars: Mapping[str, str]): - """Test that log message prefix with correct rank info.""" - with mock.patch.dict(os.environ, env_vars, clear=True): - from pytorch_lightning.utilities.rank_zero import _get_rank - - rank = _get_rank() - - seed_utils.seed_everything(123) - - expected_log = f"[rank: {rank}] Global seed set to 123" - - log_mock.assert_called_once_with(expected_log)