Simplify root node resolution for SLURM environment (#14912)
Co-authored-by: Seppo Enarvi <seppo.git@marjaniemi.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
cd9247a782
commit
d7af8ce2a5
|
@ -47,14 +47,8 @@ class SLURMEnvironment(ClusterEnvironment):
|
|||
|
||||
@property
|
||||
def main_address(self) -> str:
|
||||
# figure out the root node addr
|
||||
slurm_nodelist = os.environ.get("SLURM_NODELIST")
|
||||
if slurm_nodelist:
|
||||
root_node = slurm_nodelist.split(" ")[0].split(",")[0]
|
||||
else:
|
||||
root_node = "127.0.0.1"
|
||||
|
||||
root_node = self.resolve_root_node_address(root_node)
|
||||
nodelist = os.environ.get("SLURM_NODELIST", "127.0.0.1")
|
||||
root_node = self.resolve_root_node_address(nodelist)
|
||||
os.environ["MASTER_ADDR"] = root_node
|
||||
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
|
||||
return root_node
|
||||
|
@ -134,14 +128,16 @@ class SLURMEnvironment(ClusterEnvironment):
|
|||
def node_rank(self) -> int:
|
||||
return int(os.environ["SLURM_NODEID"])
|
||||
|
||||
def resolve_root_node_address(self, root_node: str) -> str:
|
||||
if "[" in root_node:
|
||||
name, numbers = root_node.split("[", maxsplit=1)
|
||||
number = numbers.split(",", maxsplit=1)[0]
|
||||
if "-" in number:
|
||||
number = number.split("-")[0]
|
||||
@staticmethod
|
||||
def resolve_root_node_address(nodes: str) -> str:
|
||||
"""The node selection format in SLURM supports several formats.
|
||||
|
||||
number = re.sub("[^0-9]", "", number)
|
||||
root_node = name + number
|
||||
This function selects the first host name from
|
||||
|
||||
return root_node
|
||||
- a space-separated list of host names, e.g., 'host0 host1 host3' yields 'host0' as the root
|
||||
- a comma-separated list of host names, e.g., 'host0,host1,host3' yields 'host0' as the root
|
||||
- the range notation with brackets, e.g., 'host[5-9]' yields 'host5' as the root
|
||||
"""
|
||||
nodes = re.sub(r"\[(.*?)[,-].*\]", "\\1", nodes) # Take the first node of every node range
|
||||
nodes = re.sub(r"\[(.*?)\]", "\\1", nodes) # handle special case where node range is single number
|
||||
return nodes.split(" ")[0].split(",")[0]
|
||||
|
|
|
@ -81,7 +81,19 @@ def test_attributes_from_environment_variables(caplog):
|
|||
|
||||
@pytest.mark.parametrize(
|
||||
"slurm_node_list,expected",
|
||||
[("alpha,beta,gamma", "alpha"), ("alpha beta gamma", "alpha"), ("1.2.3.[100-110]", "1.2.3.100")],
|
||||
[
|
||||
("127.0.0.1", "127.0.0.1"),
|
||||
("alpha", "alpha"),
|
||||
("alpha,beta,gamma", "alpha"),
|
||||
("alpha beta gamma", "alpha"),
|
||||
("1.2.3.[100-110]", "1.2.3.100"),
|
||||
("1.2.3.[089, 100-110]", "1.2.3.089"),
|
||||
("host[22]", "host22"),
|
||||
("host[1,5-9]", "host1"),
|
||||
("host[5-9,1]", "host5"),
|
||||
("alpha, host[5-9], gamma", "alpha"),
|
||||
("alpha[3,1], beta", "alpha3"),
|
||||
],
|
||||
)
|
||||
def test_main_address_from_slurm_node_list(slurm_node_list, expected):
|
||||
"""Test extracting the main node from different formats for the SLURM_NODELIST."""
|
||||
|
|
|
@ -157,14 +157,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
|
|||
logger=logger,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
# test root model address
|
||||
assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment)
|
||||
assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc"
|
||||
assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23]") == "abc23"
|
||||
assert trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24]") == "abc23"
|
||||
generated = trainer.strategy.cluster_environment.resolve_root_node_address("abc[23-24, 45-40, 40]")
|
||||
assert generated == "abc23"
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.plugins.precision.apex_amp.ApexMixedPrecisionPlugin.backward")
|
||||
|
|
Loading…
Reference in New Issue