Simplify root node resolution for SLURM environment (#14912)

Co-authored-by: Seppo Enarvi <seppo.git@marjaniemi.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Adrian Wälchli 2022-09-30 17:40:43 +02:00 committed by GitHub
parent cd9247a782
commit d7af8ce2a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 25 deletions

View File

@ -47,14 +47,8 @@ class SLURMEnvironment(ClusterEnvironment):
@property
def main_address(self) -> str:
    """Resolve the address of the root node and publish it as ``MASTER_ADDR``.

    The address is derived from the ``SLURM_NODELIST`` environment variable via
    :meth:`resolve_root_node_address`. When the variable is unset or empty, ``127.0.0.1`` is used.
    As a side effect, the resolved value is written to ``os.environ["MASTER_ADDR"]``.

    Returns:
        The resolved root node address.
    """
    # Use `or` instead of a `get` default: a SLURM_NODELIST that is exported but empty must also fall
    # back to localhost, otherwise MASTER_ADDR would be set to an empty string.
    nodelist = os.environ.get("SLURM_NODELIST") or "127.0.0.1"
    root_node = self.resolve_root_node_address(nodelist)
    os.environ["MASTER_ADDR"] = root_node
    log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
    return root_node
@ -134,14 +128,16 @@ class SLURMEnvironment(ClusterEnvironment):
def node_rank(self) -> int:
    """Return the rank of the current node as reported by SLURM.

    The value is read from the ``SLURM_NODEID`` environment variable and converted to ``int``.

    Raises:
        KeyError: If ``SLURM_NODEID`` is not set in the environment.
    """
    node_id = os.environ["SLURM_NODEID"]
    return int(node_id)
@staticmethod
def resolve_root_node_address(nodes: str) -> str:
    """Return the first host name described by a SLURM node list.

    SLURM can express the nodes of a job in several ways. This function selects the first host from

    - a space-separated list of host names, e.g., 'host0 host1 host3' yields 'host0' as the root
    - a comma-separated list of host names, e.g., 'host0,host1,host3' yields 'host0' as the root
    - the range notation with brackets, e.g., 'host[5-9]' yields 'host5' and 'host[22]' yields 'host22'

    Args:
        nodes: The raw node list, e.g. the value of the ``SLURM_NODELIST`` environment variable.

    Returns:
        The name (or address) of the first node in the list.
    """
    # Collapse every bracket group to its first entry. Both character classes stop at the closing ']',
    # so each group is rewritten independently: 'host[22],other[1-3]' becomes 'host22,other1' instead of
    # letting one match run across several groups and leak a stray ']' into the host name.
    nodes = re.sub(r"\[([^\],-]*)[^\]]*\]", r"\1", nodes)
    # With the brackets gone, the remaining commas and spaces only separate hosts.
    return nodes.split(" ")[0].split(",")[0]

View File

@ -81,7 +81,19 @@ def test_attributes_from_environment_variables(caplog):
@pytest.mark.parametrize(
"slurm_node_list,expected",
[("alpha,beta,gamma", "alpha"), ("alpha beta gamma", "alpha"), ("1.2.3.[100-110]", "1.2.3.100")],
[
("127.0.0.1", "127.0.0.1"),
("alpha", "alpha"),
("alpha,beta,gamma", "alpha"),
("alpha beta gamma", "alpha"),
("1.2.3.[100-110]", "1.2.3.100"),
("1.2.3.[089, 100-110]", "1.2.3.089"),
("host[22]", "host22"),
("host[1,5-9]", "host1"),
("host[5-9,1]", "host5"),
("alpha, host[5-9], gamma", "alpha"),
("alpha[3,1], beta", "alpha3"),
],
)
def test_main_address_from_slurm_node_list(slurm_node_list, expected):
"""Test extracting the main node from different formats for the SLURM_NODELIST."""

View File

@ -157,14 +157,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
logger=logger,
)
trainer.fit(model)
# test root model address
assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment)
assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc"
assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23]") == "abc23"
assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24]") == "abc23"
generated = trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24, 45-40, 40]")
assert generated == "abc23"
@mock.patch("pytorch_lightning.plugins.precision.apex_amp.ApexMixedPrecisionPlugin.backward")