Update `rank_zero_only` decorator for LSF environments (#12587)
This commit is contained in:
parent
800580a131
commit
6aa8e26a4e
|
@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed an issue where incorrect type warnings appear when the overridden `LightningLite.run` method accepts user-defined arguments ([#12629](https://github.com/PyTorchLightning/pytorch-lightning/pull/12629))
|
||||
|
||||
|
||||
-
|
||||
- Fixed `rank_zero_only` decorator in LSF environments ([#12587](https://github.com/PyTorchLightning/pytorch-lightning/pull/12587))
|
||||
|
||||
|
||||
-
|
||||
|
|
|
@ -39,7 +39,7 @@ def rank_zero_only(fn: Callable) -> Callable:
|
|||
def _get_rank() -> int:
|
||||
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
|
||||
# therefore LOCAL_RANK needs to be checked first
|
||||
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID")
|
||||
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
|
||||
for key in rank_keys:
|
||||
rank = os.environ.get(key)
|
||||
if rank is not None:
|
||||
|
|
|
@ -18,7 +18,7 @@ from unittest import mock
|
|||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
|
||||
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}, {"JSM_NAMESPACE_RANK": "0"}])
|
||||
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
|
||||
"""Test that SLURM environment variables are properly checked for rank_zero_only."""
|
||||
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
|
||||
|
@ -38,7 +38,9 @@ def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
|
|||
assert x == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
|
||||
@pytest.mark.parametrize(
|
||||
"rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3"), ("JSM_NAMESPACE_RANK", "4")]
|
||||
)
|
||||
def test_rank_zero_none_set(rank_key, rank):
|
||||
"""Test that function is not called when rank environment variables are not global zero."""
|
||||
|
||||
|
@ -58,9 +60,10 @@ def test_rank_zero_none_set(rank_key, rank):
|
|||
@pytest.mark.parametrize(
|
||||
"environ,expected_rank",
|
||||
[
|
||||
({"SLURM_PROCID": "2"}, 2),
|
||||
({"SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1),
|
||||
({"SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0),
|
||||
({"JSM_NAMESPACE_RANK": "3"}, 3),
|
||||
({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2"}, 2),
|
||||
({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1),
|
||||
({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0),
|
||||
],
|
||||
)
|
||||
def test_rank_zero_priority(environ, expected_rank):
|
||||
|
|
Loading…
Reference in New Issue