From d7af8ce2a5fd0aa3c0d69fedb5fbabd0e1be09f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 30 Sep 2022 17:40:43 +0200 Subject: [PATCH] Simplify root node resolution for SLURM environment (#14912) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Seppo Enarvi Co-authored-by: Carlos MocholĂ­ --- .../plugins/environments/slurm_environment.py | 30 ++++++++----------- .../environments/test_slurm_environment.py | 14 ++++++++- tests/tests_pytorch/models/test_amp.py | 7 ----- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/lightning_lite/plugins/environments/slurm_environment.py b/src/lightning_lite/plugins/environments/slurm_environment.py index 7b7b3e5fa6..074f302eac 100644 --- a/src/lightning_lite/plugins/environments/slurm_environment.py +++ b/src/lightning_lite/plugins/environments/slurm_environment.py @@ -47,14 +47,8 @@ class SLURMEnvironment(ClusterEnvironment): @property def main_address(self) -> str: - # figure out the root node addr - slurm_nodelist = os.environ.get("SLURM_NODELIST") - if slurm_nodelist: - root_node = slurm_nodelist.split(" ")[0].split(",")[0] - else: - root_node = "127.0.0.1" - - root_node = self.resolve_root_node_address(root_node) + nodelist = os.environ.get("SLURM_NODELIST", "127.0.0.1") + root_node = self.resolve_root_node_address(nodelist) os.environ["MASTER_ADDR"] = root_node log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") return root_node @@ -134,14 +128,16 @@ class SLURMEnvironment(ClusterEnvironment): def node_rank(self) -> int: return int(os.environ["SLURM_NODEID"]) - def resolve_root_node_address(self, root_node: str) -> str: - if "[" in root_node: - name, numbers = root_node.split("[", maxsplit=1) - number = numbers.split(",", maxsplit=1)[0] - if "-" in number: - number = number.split("-")[0] + @staticmethod + def resolve_root_node_address(nodes: str) -> str: + """The node selection format in SLURM supports several formats. - number = re.sub("[^0-9]", "", number) - root_node = name + number + This function selects the first host name from - return root_node + - a space-separated list of host names, e.g., 'host0 host1 host3' yields 'host0' as the root + - a comma-separated list of host names, e.g., 'host0,host1,host3' yields 'host0' as the root + - the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root + """ + nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range + nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number + return nodes.split(" ")[0].split(",")[0] diff --git a/tests/tests_lite/plugins/environments/test_slurm_environment.py b/tests/tests_lite/plugins/environments/test_slurm_environment.py index 441f996127..97c6d10494 100644 --- a/tests/tests_lite/plugins/environments/test_slurm_environment.py +++ b/tests/tests_lite/plugins/environments/test_slurm_environment.py @@ -81,7 +81,19 @@ def test_attributes_from_environment_variables(caplog): @pytest.mark.parametrize( "slurm_node_list,expected", - [("alpha,beta,gamma", "alpha"), ("alpha beta gamma", "alpha"), ("1.2.3.[100-110]", "1.2.3.100")], + [ + ("127.0.0.1", "127.0.0.1"), + ("alpha", "alpha"), + ("alpha,beta,gamma", "alpha"), + ("alpha beta gamma", "alpha"), + ("1.2.3.[100-110]", "1.2.3.100"), + ("1.2.3.[089, 100-110]", "1.2.3.089"), + ("host[22]", "host22"), + ("host[1,5-9]", "host1"), + ("host[5-9,1]", "host5"), + ("alpha, host[5-9], gamma", "alpha"), + ("alpha[3,1], beta", "alpha3"), + ], ) def test_main_address_from_slurm_node_list(slurm_node_list, expected): """Test extracting the main node from different formats for the SLURM_NODELIST.""" diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index ce1ed69d50..74bd4c20ab 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -157,14 +157,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): logger=logger, ) trainer.fit(model) - - # test root model address assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) - assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc" - assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23]") == "abc23" - assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24]") == "abc23" - generated = trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24, 45-40, 40]") - assert generated == "abc23" @mock.patch("pytorch_lightning.plugins.precision.apex_amp.ApexMixedPrecisionPlugin.backward")