Update `rank_zero_only` decorator for LSF environments (#12587)

This commit is contained in:
Adrian Wälchli 2022-04-07 12:46:55 +02:00 committed by GitHub
parent 800580a131
commit 6aa8e26a4e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 7 deletions

View File

@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where incorrect type warnings appear when the overridden `LightningLite.run` method accepts user-defined arguments ([#12629](https://github.com/PyTorchLightning/pytorch-lightning/pull/12629))
-
- Fixed `rank_zero_only` decorator in LSF environments ([#12587](https://github.com/PyTorchLightning/pytorch-lightning/pull/12587))
-

View File

@ -39,7 +39,7 @@ def rank_zero_only(fn: Callable) -> Callable:
def _get_rank() -> int:
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
# therefore LOCAL_RANK needs to be checked first
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID")
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
for key in rank_keys:
rank = os.environ.get(key)
if rank is not None:

View File

@ -18,7 +18,7 @@ from unittest import mock
import pytest
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}])
@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}, {"JSM_NAMESPACE_RANK": "0"}])
def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
"""Test that known cluster environment variables (e.g. SLURM, LSF) are properly checked for rank_zero_only."""
from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only
@ -38,7 +38,9 @@ def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]):
assert x == 1
@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")])
@pytest.mark.parametrize(
"rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3"), ("JSM_NAMESPACE_RANK", "4")]
)
def test_rank_zero_none_set(rank_key, rank):
"""Test that function is not called when rank environment variables are not global zero."""
@ -58,9 +60,10 @@ def test_rank_zero_none_set(rank_key, rank):
@pytest.mark.parametrize(
"environ,expected_rank",
[
({"SLURM_PROCID": "2"}, 2),
({"SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1),
({"SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0),
({"JSM_NAMESPACE_RANK": "3"}, 3),
({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2"}, 2),
({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1),
({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0),
],
)
def test_rank_zero_priority(environ, expected_rank):