Unify rank zero messaging utilities (#14116)

commit 291267c3bf (parent 18e2a8eecd)
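In short, this change replaces ad-hoc f-string prefixes such as f"[rank: {trainer.global_rank}] {message}" with two shared helpers, _get_rank and _rank_prefixed_message. A minimal sketch of the resulting call pattern, reproducing the helper from the diff below purely for illustration:

from typing import Optional


def _rank_prefixed_message(message: str, rank: Optional[int]) -> str:
    # prepend "[rank: N]" only when a rank is actually known
    if rank is not None:
        return f"[rank: {rank}] {message}"
    return message


print(_rank_prefixed_message("Global seed set to 123", 2))     # [rank: 2] Global seed set to 123
print(_rank_prefixed_message("Global seed set to 123", None))  # Global seed set to 123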
@@ -28,7 +28,7 @@ from torch import Tensor
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.rank_zero import rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message, rank_zero_warn
 
 log = logging.getLogger(__name__)
@@ -259,14 +259,9 @@ class EarlyStopping(Callback):
 
     @staticmethod
     def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
-        if trainer:
-            # ignore logging in non-zero ranks if log_rank_zero_only flag is enabled
-            if log_rank_zero_only and trainer.global_rank != 0:
-                return
-            # if world size is more than one then specify the rank of the process being logged
-            if trainer.world_size > 1:
-                log.info(f"[rank: {trainer.global_rank}] {message}")
-                return
-
-        # if above conditions don't meet and we have to log
-        log.info(message)
+        rank = _get_rank(trainer)
+        if trainer is not None and trainer.world_size <= 1:
+            rank = None
+        message = _rank_prefixed_message(message, rank)
+        if rank is None or not log_rank_zero_only or rank == 0:
+            log.info(message)
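To make the rewritten EarlyStopping._log_info concrete, here is a self-contained sketch of its decision logic. FakeTrainer is a hypothetical stand-in for pl.Trainer, and the rank lookup is simplified to the trainer branch only (the real _get_rank also falls back to the RANK/LOCAL_RANK/SLURM_PROCID/JSM_NAMESPACE_RANK environment variables):

import logging
from dataclasses import dataclass
from typing import Optional

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


@dataclass
class FakeTrainer:  # hypothetical stand-in for pl.Trainer
    global_rank: int
    world_size: int


def _rank_prefixed_message(message: str, rank: Optional[int]) -> str:
    return f"[rank: {rank}] {message}" if rank is not None else message


def _log_info(trainer: Optional[FakeTrainer], message: str, log_rank_zero_only: bool) -> None:
    # mirrors the new EarlyStopping._log_info in the hunk above
    rank = trainer.global_rank if trainer is not None else None
    if trainer is not None and trainer.world_size <= 1:
        rank = None  # single-process runs get no rank prefix
    message = _rank_prefixed_message(message, rank)
    if rank is None or not log_rank_zero_only or rank == 0:
        log.info(message)


_log_info(FakeTrainer(0, 2), "stopping", log_rank_zero_only=True)   # logs "[rank: 0] stopping"
_log_info(FakeTrainer(1, 2), "stopping", log_rank_zero_only=True)   # suppressed on non-zero ranks
_log_info(FakeTrainer(1, 2), "stopping", log_rank_zero_only=False)  # logs "[rank: 1] stopping"
_log_info(FakeTrainer(0, 1), "stopping", log_rank_zero_only=True)   # logs "stopping" (no prefix)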
@@ -20,6 +20,8 @@ from functools import partial, wraps
 from platform import python_version
 from typing import Any, Callable, Optional, Union
 
+import pytorch_lightning as pl
+
 log = logging.getLogger(__name__)
 
 
@@ -35,8 +37,9 @@ def rank_zero_only(fn: Callable) -> Callable:
     return wrapped_fn
 
 
 # TODO: this should be part of the cluster environment
-def _get_rank() -> int:
+def _get_rank(trainer: Optional["pl.Trainer"] = None) -> Optional[int]:
+    if trainer is not None:
+        return trainer.global_rank
     # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
     # therefore LOCAL_RANK needs to be checked first
     rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
@@ -44,11 +47,12 @@ def _get_rank() -> int:
         rank = os.environ.get(key)
         if rank is not None:
             return int(rank)
-    return 0
+    # None to differentiate whether an environment variable was set at all
+    return None
 
 
 # add the attribute to the function but don't overwrite in case Trainer has already set it
-rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())
+rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank() or 0)
 
 
 def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
@@ -97,3 +101,10 @@ class LightningDeprecationWarning(DeprecationWarning):
 
 
 rank_zero_deprecation = partial(rank_zero_warn, category=LightningDeprecationWarning)
+
+
+def _rank_prefixed_message(message: str, rank: Optional[int]) -> str:
+    if rank is not None:
+        # specify the rank of the process being logged
+        return f"[rank: {rank}] {message}"
+    return message
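A subtlety in the hunks above: _get_rank() now returns None when neither a trainer nor any rank environment variable is available, which is why the module-level default becomes `_get_rank() or 0`. A sketch of just that coercion (the function name here is hypothetical; its body matches the environment loop of the new _get_rank):

import os
from typing import Optional


def _get_rank_from_env() -> Optional[int]:
    # environment-only branch of the new _get_rank
    for key in ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK"):
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return None  # None distinguishes "nothing set" from an explicit rank 0


# `or 0` maps both None and 0 to 0, preserving the old default of rank 0
rank_zero_default = _get_rank_from_env() or 0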
@@ -24,7 +24,7 @@ from typing import Any, Dict, Generator, Optional
 import numpy as np
 import torch
 
-from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message, rank_zero_only, rank_zero_warn
 
 log = logging.getLogger(__name__)
 
@@ -66,7 +66,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
         rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
         seed = _select_seed_randomly(min_seed_value, max_seed_value)
 
-    log.info(f"[rank: {_get_rank()}] Global seed set to {seed}")
+    log.info(_rank_prefixed_message(f"Global seed set to {seed}", _get_rank()))
     os.environ["PL_GLOBAL_SEED"] = str(seed)
     random.seed(seed)
    np.random.seed(seed)
@@ -472,9 +472,8 @@ def test_early_stopping_squeezes():
         (True, 2, 1, None),
     ],
 )
-def test_early_stopping_log_info(tmpdir, trainer, log_rank_zero_only, world_size, global_rank, expected_log):
-    """checks if log.info() gets called with expected message when used within EarlyStopping."""
-
+def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
+    """Checks if log.info() gets called with expected message when used within EarlyStopping."""
     # set the global_rank and world_size if trainer is not None
     # or else always expect the simple logging message
     if trainer:
@@ -11,20 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Test deprecated functionality which will be removed in vX.Y.Z."""
-import sys
 from contextlib import contextmanager
 from typing import Optional
 
 from tests_pytorch.helpers.utils import no_warning_call
 
 
-def _soft_unimport_module(str_module):
-    # once the module is imported e.g with parsing with pytest it lives in memory
-    if str_module in sys.modules:
-        del sys.modules[str_module]
-
-
 @contextmanager
 def no_deprecated_call(match: Optional[str] = None):
     with no_warning_call(expected_warning=DeprecationWarning, match=match):
@@ -12,49 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Mapping
+import sys
 from unittest import mock
 
 import pytest
 
-
-@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}, {"JSM_NAMESPACE_RANK": "0"}])
-def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
-    """Test that SLURM environment variables are properly checked for rank_zero_only."""
-    from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
-
-    rank_zero_only.rank = _get_rank()
-
-    with mock.patch.dict(os.environ, env_vars):
-        from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
-
-        rank_zero_only.rank = _get_rank()
-
-        @rank_zero_only
-        def foo():  # The return type is optional because on non-zero ranks it will not be called
-            return 1
-
-        x = foo()
-        assert x == 1
+from pytorch_lightning.utilities.rank_zero import _get_rank, _rank_prefixed_message
 
 
 @pytest.mark.parametrize(
-    "rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3"), ("JSM_NAMESPACE_RANK", "4")]
+    "env_vars, expected",
+    [
+        ({"RANK": "0"}, 1),
+        ({"SLURM_PROCID": "0"}, 1),
+        ({"LOCAL_RANK": "0"}, 1),
+        ({"JSM_NAMESPACE_RANK": "0"}, 1),
+        ({}, 1),
+        ({"RANK": "1"}, None),
+        ({"SLURM_PROCID": "2"}, None),
+        ({"LOCAL_RANK": "3"}, None),
+        ({"JSM_NAMESPACE_RANK": "4"}, None),
+    ],
 )
-def test_rank_zero_none_set(rank_key, rank):
-    """Test that function is not called when rank environment variables are not global zero."""
-
-    with mock.patch.dict(os.environ, {rank_key: rank}):
-        from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
-
-        rank_zero_only.rank = _get_rank()
+def test_rank_zero_known_environment_variables(env_vars, expected):
+    """Test that rank environment variables are properly checked for rank_zero_only."""
+    with mock.patch.dict(os.environ, env_vars):
+        # force module reload to re-trigger the rank_zero_only.rank global computation
+        sys.modules.pop("pytorch_lightning.utilities.rank_zero", None)
+        from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
         @rank_zero_only
         def foo():
             return 1
 
-        x = foo()
-        assert x is None
+        assert foo() == expected
 
 
 @pytest.mark.parametrize(
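The sys.modules.pop call in the rewritten test is the load-bearing detail: rank_zero_only.rank is computed once at module import time, so the module must be evicted and re-imported for a patched environment to be observed. The mechanism in isolation (assumes pytorch_lightning is importable; a sketch of the pattern, not an additional test):

import os
import sys
from unittest import mock

with mock.patch.dict(os.environ, {"RANK": "3"}):
    # evict the cached module so the next import re-runs its top-level code
    sys.modules.pop("pytorch_lightning.utilities.rank_zero", None)
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    # recomputed at import time from the patched environment
    assert rank_zero_only.rank == 3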
@@ -69,6 +60,13 @@ def test_rank_zero_none_set(rank_key, rank):
 def test_rank_zero_priority(environ, expected_rank):
     """Test the priority in which the rank gets determined when multiple environment variables are available."""
     with mock.patch.dict(os.environ, environ):
-        from pytorch_lightning.utilities.rank_zero import _get_rank
-
         assert _get_rank() == expected_rank
+
+
+@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}])
+def test_rank_prefixed_message_with_env_vars(env_vars):
+    with mock.patch.dict(os.environ, env_vars, clear=True):
+        rank = _get_rank()
+        message = _rank_prefixed_message("bar", rank)
+
+        assert message == f"[rank: {rank}] bar"
@@ -1,8 +1,6 @@
 import os
 import random
-from typing import Mapping
 from unittest import mock
-from unittest.mock import MagicMock
 
 import numpy as np
 import pytest
@@ -116,19 +114,3 @@ def test_backward_compatibility_rng_states_dict():
     assert "torch.cuda" in states
     states.pop("torch.cuda")
     _set_rng_states(states)
-
-
-@mock.patch("pytorch_lightning.utilities.seed.log.info")
-@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"RANK": "1"}, {"RANK": "4"}])
-def test_seed_everything_log_info(log_mock: MagicMock, env_vars: Mapping[str, str]):
-    """Test that log message prefix with correct rank info."""
-    with mock.patch.dict(os.environ, env_vars, clear=True):
-        from pytorch_lightning.utilities.rank_zero import _get_rank
-
-        rank = _get_rank()
-
-        seed_utils.seed_everything(123)
-
-        expected_log = f"[rank: {rank}] Global seed set to 123"
-
-        log_mock.assert_called_once_with(expected_log)