diff --git a/CHANGELOG.md b/CHANGELOG.md index f6afaccf0e..4037dd3632 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,7 +91,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue where incorrect type warnings appear when the overridden `LightningLite.run` method accepts user-defined arguments ([#12629](https://github.com/PyTorchLightning/pytorch-lightning/pull/12629)) -- +- Fixed `rank_zero_only` decorator in LSF environments ([#12587](https://github.com/PyTorchLightning/pytorch-lightning/pull/12587)) - diff --git a/pytorch_lightning/utilities/rank_zero.py b/pytorch_lightning/utilities/rank_zero.py index 513798ff7a..b503fa3727 100644 --- a/pytorch_lightning/utilities/rank_zero.py +++ b/pytorch_lightning/utilities/rank_zero.py @@ -39,7 +39,7 @@ def rank_zero_only(fn: Callable) -> Callable: def _get_rank() -> int: # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, # therefore LOCAL_RANK needs to be checked first - rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID") + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") for key in rank_keys: rank = os.environ.get(key) if rank is not None: diff --git a/tests/utilities/test_rank_zero.py b/tests/utilities/test_rank_zero.py index 15a55fdd87..ebc827cc46 100644 --- a/tests/utilities/test_rank_zero.py +++ b/tests/utilities/test_rank_zero.py @@ -18,7 +18,7 @@ from unittest import mock import pytest -@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}]) +@pytest.mark.parametrize("env_vars", [{"RANK": "0"}, {"SLURM_PROCID": "0"}, {"JSM_NAMESPACE_RANK": "0"}]) def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]): """Test that SLURM environment variables are properly checked for rank_zero_only.""" from pytorch_lightning.utilities.rank_zero import _get_rank, rank_zero_only @@ -38,7 +38,9 @@ def test_rank_zero_known_cluster_envs(env_vars: Mapping[str, str]): assert x == 1 -@pytest.mark.parametrize("rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3")]) +@pytest.mark.parametrize( + "rank_key,rank", [("RANK", "1"), ("SLURM_PROCID", "2"), ("LOCAL_RANK", "3"), ("JSM_NAMESPACE_RANK", "4")] +) def test_rank_zero_none_set(rank_key, rank): """Test that function is not called when rank environment variables are not global zero.""" @@ -58,9 +60,10 @@ def test_rank_zero_none_set(rank_key, rank): @pytest.mark.parametrize( "environ,expected_rank", [ - ({"SLURM_PROCID": "2"}, 2), - ({"SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1), - ({"SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0), + ({"JSM_NAMESPACE_RANK": "3"}, 3), + ({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2"}, 2), + ({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2", "LOCAL_RANK": "1"}, 1), + ({"JSM_NAMESPACE_RANK": "3", "SLURM_PROCID": "2", "LOCAL_RANK": "1", "RANK": "0"}, 0), ], ) def test_rank_zero_priority(environ, expected_rank):