Unify rank zero messaging utilities (#14116)

This commit is contained in:
Carlos Mocholí 2022-08-30 11:51:30 +02:00 committed by GitHub
parent 18e2a8eecd
commit 291267c3bf
7 changed files with 56 additions and 79 deletions

View File

@@ -28,7 +28,7 @@ from torch import Tensor
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message, rank_zero_warn
log = logging.getLogger(__name__)
@@ -259,14 +259,9 @@ class EarlyStopping(Callback):
@staticmethod
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
if trainer:
# ignore logging in non-zero ranks if log_rank_zero_only flag is enabled
if log_rank_zero_only and trainer.global_rank != 0:
return
# if world size is more than one then specify the rank of the process being logged
if trainer.world_size > 1:
log.info(f"[rank: {trainer.global_rank}] {message}")
return
# if above conditions don't meet and we have to log
log.info(message)
rank = _get_rank(trainer)
if trainer is not None and trainer.world_size <= 1:
rank = None
message = _rank_prefixed_message(message, rank)
if rank is None or not log_rank_zero_only or rank == 0:
log.info(message)
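
A minimal sketch (not part of this diff) of how the rewritten _log_info behaves; the Mock trainer here is a stand-in for a real Trainer:

import logging
from unittest.mock import Mock
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

logging.basicConfig(level=logging.INFO)

# single-process run: rank is forced to None, so no prefix is added
EarlyStopping._log_info(Mock(global_rank=0, world_size=1), "stopping", log_rank_zero_only=False)
# -> "stopping"

# multi-process run: every rank logs a prefixed message ...
EarlyStopping._log_info(Mock(global_rank=1, world_size=2), "stopping", log_rank_zero_only=False)
# -> "[rank: 1] stopping"

# ... unless log_rank_zero_only=True, in which case non-zero ranks stay silent
EarlyStopping._log_info(Mock(global_rank=1, world_size=2), "stopping", log_rank_zero_only=True)
# (no output)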

View File

@@ -20,6 +20,8 @@ from functools import partial, wraps
from platform import python_version
from typing import Any, Callable, Optional, Union
import pytorch_lightning as pl
log = logging.getLogger(__name__)
@@ -35,8 +37,9 @@ def rank_zero_only(fn: Callable) -> Callable:
return wrapped_fn
# TODO: this should be part of the cluster environment
def _get_rank() -> int:
def _get_rank(trainer: Optional["pl.Trainer"] = None) -> Optional[int]:
if trainer is not None:
return trainer.global_rank
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
# therefore LOCAL_RANK needs to be checked first
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
@@ -44,11 +47,12 @@ def _get_rank() -> int:
rank = os.environ.get(key)
if rank is not None:
return int(rank)
return 0
# None to differentiate whether an environment variable was set at all
return None
# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0)
def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
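
With this change, _get_rank returns None when no rank-related environment variable is set, so callers can tell "unset" apart from "rank zero". A small sketch of the resolution order, under assumed environments:

import os
from unittest import mock
from pytorch_lightning.utilities.rank_zero import _get_rank

# LOCAL_RANK wins over SLURM_PROCID, since SLURM_PROCID can be set
# even when SLURM is not managing the processes
with mock.patch.dict(os.environ, {"LOCAL_RANK": "1", "SLURM_PROCID": "2"}, clear=True):
    assert _get_rank() == 1

# no rank variables at all: None rather than 0
with mock.patch.dict(os.environ, {}, clear=True):
    assert _get_rank() is None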
@@ -97,3 +101,10 @@ class LightningDeprecationWarning(DeprecationWarning):
rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)
def _rank_prefixed_message(message: str, rank: Optional[int]) -> str:
if rank is not None:
# specify the rank of the process being logged
return f"[rank: {rank}] {message}"
return message
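
The new helper composes with _get_rank; a trivial usage example (not part of this diff):

from pytorch_lightning.utilities.rank_zero import _rank_prefixed_message

assert _rank_prefixed_message("Global seed set to 123", 2) == "[rank: 2] Global seed set to 123"
# an unknown rank (None) leaves the message untouched
assert _rank_prefixed_message("Global seed set to 123", None) == "Global seed set to 123"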

View File

@@ -24,7 +24,7 @@ from typing import Any, Dict, Generator, Optional
import numpy as np
import torch
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message, rank_zero_only, rank_zero_warn
log = logging.getLogger(__name__)
@@ -66,7 +66,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
seed = _select_seed_randomly(min_seed_value, max_seed_value)
log.info(f"[rank: {_get_rank()}] Global seed set to {seed}")
log.info(_rank_prefixed_message(f"Global seed set to {seed}", _get_rank()))
os.environ["PL_GLOBAL_SEED"] = str(seed)
random.seed(seed)
np.random.seed(seed)
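
With this change, the seed message is only prefixed when a rank can actually be determined. A sketch of the resulting log lines (expected output shown as comments; assumes logging is configured at INFO level and no PL state leaks in from earlier runs):

import logging
from pytorch_lightning.utilities.seed import seed_everything

logging.basicConfig(level=logging.INFO)
seed_everything(123)
# with no rank environment variables set: "Global seed set to 123"
# under e.g. RANK=1:                      "[rank: 1] Global seed set to 123"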

View File

@@ -472,9 +472,8 @@ def test_early_stopping_squeezes():
(True, 2, 1, None),
],
)
def test_early_stopping_log_info(tmpdir, trainer, log_rank_zero_only, world_size, global_rank, expected_log):
"""checks if log.info() gets called with expected message when used within EarlyStopping."""
def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
"""Checks if log.info() gets called with expected message when used within EarlyStopping."""
# set the global_rank and world_size if trainer is not None
# or else always expect the simple logging message
if trainer:

View File

@@ -11,20 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z."""
import sys
from contextlib import contextmanager
from typing import Optional
from tests_pytorch.helpers.utils import no_warning_call
def _soft_unimport_module(str_module):
# once the module is imported e.g with parsing with pytest it lives in memory
if str_module in sys.modules:
del sys.modules[str_module]
@contextmanager
def no_deprecated_call(match: Optional[str] = None):
with no_warning_call(expected_warning=DeprecationWarning, match=match):

View File

@@ -12,49 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Mapping
import sys
from unittest import mock
import pytest
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}, {"JSM_NAMESPACE_RANK": "0"}])
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
"""Test that SLURM environment variables are properly checked for rank_zero_only."""
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
with mock.patch.dict(os.environ, env_vars):
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
@rank_zero_only
def foo(): # The return type is optional because on non-zero ranks it will not be called
return 1
x = foo()
assert x == 1
from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message
@pytest.mark.parametrize(
"rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3"), ("JSM_NAMESPACE_RANK", "4")]
"env_vars, expected",
[
({"RANK": "0"}, 1),
({"SLURM_PROCID": "0"}, 1),
({"LOCAL_RANK": "0"}, 1),
({"JSM_NAMESPACE_RANK": "0"}, 1),
({}, 1),
({"RANK": "1"}, None),
({"SLURM_PROCID": "2"}, None),
({"LOCAL_RANK": "3"}, None),
({"JSM_NAMESPACE_RANK": "4"}, None),
],
)
def test_rank_zero_none_set(rank_key, rank):
"""Test that function is not called when rank environment variables are not global zero."""
with mock.patch.dict(os.environ, {rank_key: rank}):
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
rank_zero_only.rank = _get_rank()
def test_rank_zero_known_environment_variables(env_vars, expected):
"""Test that rank environment variables are properly checked for rank_zero_only."""
with mock.patch.dict(os.environ, env_vars):
# force module reload to re-trigger the rank_zero_only.rank global computation
sys.modules.pop("pytorch_lightning.utilities.rank_zero", None)
from pytorch_lightning.utilities.rank_zero import rank_zero_only
@rank_zero_only
def foo():
return 1
x = foo()
assert x is None
assert foo() == expected
@pytest.mark.parametrize(
@@ -69,6 +60,13 @@ def test_rank_zero_none_set(rank_key, rank):
def test_rank_zero_priority(environ, expected_rank):
"""Test the priority in which the rank gets determined when multiple environment variables are available."""
with mock.patch.dict(os.environ, environ):
from pytorch_lightning.utilities.rank_zero import _get_rank
assert _get_rank() == expected_rank
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}])
def test_rank_prefixed_message_with_env_vars(env_vars):
with mock.patch.dict(os.environ, env_vars, clear=True):
rank = _get_rank()
message = _rank_prefixed_message("bar", rank)
assert message == f"[rank: {rank}] bar"

View File

@@ -1,8 +1,6 @@
import os
import random
from typing import Mapping
from unittest import mock
from unittest.mock import MagicMock
import numpy as np
import pytest
@@ -116,19 +114,3 @@ def test_backward_compatibility_rng_states_dict():
assert "torch.cuda" in states
states.pop("torch.cuda")
_set_rng_states(states)
@mock.patch("pytorch_lightning.utilities.seed.log.info")
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}])
def test_seed_everything_log_info(log_mock: MagicMock, env_vars: Mapping[str, str]):
"""Test that log message prefix with correct rank info."""
with mock.patch.dict(os.environ, env_vars, clear=True):
from pytorch_lightning.utilities.rank_zero import _get_rank
rank = _get_rank()
seed_utils.seed_everything(123)
expected_log = f"[rank: {rank}] Global seed set to 123"
log_mock.assert_called_once_with(expected_log)