Modify LSFEnvironment to use more reliable environment variable (#10825)

Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
Andrew Tritt 2022-01-05 04:45:25 -08:00 committed by GitHub
parent 93223ff5ce
commit dbf1acd5a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 154 additions and 89 deletions

View File

@ -140,6 +140,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed duplicated file extension when uploading model checkpoints with `NeptuneLogger` ([#11015](https://github.com/PyTorchLightning/pytorch-lightning/pull/11015))
- Changed `LSFEnvironment` to use `LSB_DJOB_RANKFILE` environment variable instead of `LSB_HOSTS` for determining node rank and main address ([#10825](https://github.com/PyTorchLightning/pytorch-lightning/pull/10825))
- Removed `__getstate__` and `__setstate__` of `RichProgressBar` ([#11100](https://github.com/PyTorchLightning/pytorch-lightning/pull/11100))

View File

@ -19,6 +19,7 @@ from typing import Dict, List
from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.cloud_io import get_filesystem
class LSFEnvironment(ClusterEnvironment):
@ -27,19 +28,22 @@ class LSFEnvironment(ClusterEnvironment):
It is expected that any execution using this ClusterEnvironment was executed
using the Job Step Manager i.e. ``jsrun``.
This plugin expects the following environment variables.
This plugin expects the following environment variables:
LSB_JOBID:
The LSF assigned job ID
``LSB_JOBID``
The LSF assigned job ID
LSB_HOSTS:
The hosts used in the job. This string is expected to have the format "batch <rank_0_host> ...."
``LSB_DJOB_RANKFILE``
The OpenMPI compatibile rank file for the LSF job
JSM_NAMESPACE_LOCAL_RANK:
The node local rank for the task. This environment variable is set by jsrun
``JSM_NAMESPACE_LOCAL_RANK``
The node local rank for the task. This environment variable is set by ``jsrun``
JSM_NAMESPACE_SIZE:
The world size for the task. This environment variable is set by jsrun
``JSM_NAMESPACE_SIZE``
The world size for the task. This environment variable is set by ``jsrun``
``JSM_NAMESPACE_RANK``
The global rank for the task. This environment variable is set by ``jsrun``
"""
def __init__(self) -> None:
@ -52,37 +56,45 @@ class LSFEnvironment(ClusterEnvironment):
)
self._main_address = self._get_main_address()
self._main_port = self._get_main_port()
log.debug(f"MASTER_ADDR: {self._main_address}")
log.debug(f"MASTER_PORT: {self._main_port}")
self._node_rank = self._get_node_rank()
self._set_init_progress_group_env_vars()
def _set_init_progress_group_env_vars(self) -> None:
# set environment variables needed for initializing torch distributed process group
os.environ["MASTER_ADDR"] = str(self._main_address)
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
os.environ["MASTER_PORT"] = str(self._main_port)
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
@property
def creates_processes_externally(self) -> bool:
"""LSF creates subprocesses, i.e., PyTorch Lightning does not need to spawn them."""
return True
@property
def main_address(self) -> str:
"""The main address is read from a list of hosts contained in the environment variable `LSB_HOSTS`."""
"""The main address is read from an OpenMPI host rank file in the environment variable
``LSB_DJOB_RANKFILE``."""
return self._main_address
@property
def main_port(self) -> int:
"""The main port gets calculated from the LSF job ID."""
"""The main port is calculated from the LSF job ID."""
return self._main_port
@staticmethod
def detect() -> bool:
"""Returns ``True`` if the current process was launched using the jsrun command."""
required_env_vars = {"LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
"""Returns ``True`` if the current process was launched using the ``jsrun`` command."""
required_env_vars = {"LSB_JOBID", "LSB_DJOB_RANKFILE", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
return required_env_vars.issubset(os.environ.keys())
def world_size(self) -> int:
"""The world size is read from the environment variable `JSM_NAMESPACE_SIZE`."""
var = "JSM_NAMESPACE_SIZE"
world_size = os.environ.get(var)
"""The world size is read from the environment variable ``JSM_NAMESPACE_SIZE``."""
world_size = os.environ.get("JSM_NAMESPACE_SIZE")
if world_size is None:
raise ValueError(
f"Cannot determine world size from environment variable {var}."
" Make sure you run your executable with `jsrun`"
"Cannot determine world size. Environment variable `JSM_NAMESPACE_SIZE` not found."
"Make sure you run your executable with `jsrun`."
)
return int(world_size)
@ -90,13 +102,12 @@ class LSFEnvironment(ClusterEnvironment):
log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
def global_rank(self) -> int:
"""The world size is read from the environment variable `JSM_NAMESPACE_RANK`."""
var = "JSM_NAMESPACE_RANK"
global_rank = os.environ.get(var)
"""The world size is read from the environment variable ``JSM_NAMESPACE_RANK``."""
global_rank = os.environ.get("JSM_NAMESPACE_RANK")
if global_rank is None:
raise ValueError(
f"Cannot determine global rank from environment variable {var}."
" Make sure you run your executable with `jsrun`"
"Cannot determine global rank. Environment variable `JSM_NAMESPACE_RANK` not found."
"Make sure you run your executable with `jsrun`."
)
return int(global_rank)
@ -105,42 +116,60 @@ class LSFEnvironment(ClusterEnvironment):
def local_rank(self) -> int:
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
var = "JSM_NAMESPACE_LOCAL_RANK"
local_rank = os.environ.get(var)
local_rank = os.environ.get("JSM_NAMESPACE_LOCAL_RANK")
if local_rank is None:
raise ValueError(
f"Cannot determine local rank from environment variable {var}."
" Make sure you run your executable with `jsrun`"
"Cannot determine local rank. Environment variable `JSM_NAMESPACE_LOCAL_RANK` not found."
"Make sure you run your executable with `jsrun`."
)
return int(local_rank)
def node_rank(self) -> int:
"""The node rank is determined by the position of the current hostname in the list of hosts stored in the
environment variable `LSB_HOSTS`."""
"""The node rank is determined by the position of the current hostname in the OpenMPI host rank file stored
in ``LSB_DJOB_RANKFILE``."""
return self._node_rank
def _get_node_rank(self) -> int:
"""A helper method for getting the node rank.
The node rank is determined by the position of the current node in the list of hosts used in the job. This is
calculated by reading all hosts from ``LSB_DJOB_RANKFILE`` and finding this node's hostname in the list.
"""
hosts = self._read_hosts()
count: Dict[str, int] = {}
for host in hosts:
if "batch" in host or "login" in host:
continue
if host not in count:
count[host] = len(count)
return count[socket.gethostname()]
@staticmethod
def _read_hosts() -> List[str]:
hosts = os.environ.get("LSB_HOSTS")
if not hosts:
raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
hosts = hosts.split()
if len(hosts) < 2:
raise ValueError(
'Cannot parse hosts from LSB_HOSTS environment variable. Expected format: "batch <rank_0_host> ..."'
)
return hosts
"""Read compute hosts that are a part of the compute job.
LSF uses the Job Step Manager (JSM) to manage job steps. Job steps are executed by the JSM from "launch" nodes.
Each job is assigned a launch node. This launch node will be the first node in the list contained in
``LSB_DJOB_RANKFILE``.
"""
var = "LSB_DJOB_RANKFILE"
rankfile = os.environ.get(var)
if rankfile is None:
raise ValueError("Did not find the environment variable `LSB_DJOB_RANKFILE`")
if not rankfile:
raise ValueError("The environment variable `LSB_DJOB_RANKFILE` is empty")
fs = get_filesystem(rankfile)
with fs.open(rankfile, "r") as f:
ret = [line.strip() for line in f]
# remove the launch node (i.e. the first node in LSB_DJOB_RANKFILE) from the list
return ret[1:]
def _get_main_address(self) -> str:
"""A helper for getting the main address.
The main address is assigned to the first node in the list of nodes used for the job.
"""
hosts = self._read_hosts()
return hosts[1]
return hosts[0]
@staticmethod
def _get_main_port() -> int:

View File

@ -41,6 +41,7 @@ from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule
from tests.helpers.runif import RunIf
from tests.loggers.test_base import CustomLogger
from tests.plugins.environments.test_lsf_environment import _make_rankfile
def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir):
@ -514,8 +515,7 @@ def test_v1_7_0_cluster_environment_master_port(cls):
(TorchElasticEnvironment, "is_using_torchelastic"),
],
)
@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"})
def test_v1_7_0_cluster_environment_detection(cls, method_name):
def test_v1_7_0_cluster_environment_detection(cls, method_name, tmp_path):
class MyClusterEnvironment(cls):
@staticmethod
def is_using_kubeflow():
@ -529,10 +529,19 @@ def test_v1_7_0_cluster_environment_detection(cls, method_name):
def is_using_torchelastic():
pass
with pytest.deprecated_call(
match=f"MyClusterEnvironment.{method_name}` has been deprecated in v1.6 and will be removed in v1.7"
):
MyClusterEnvironment()
environ = {
"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path),
"LSB_JOBID": "1234",
"JSM_NAMESPACE_SIZE": "4",
"JSM_NAMESPACE_RANK": "3",
"JSM_NAMESPACE_LOCAL_RANK": "1",
}
with mock.patch.dict(os.environ, environ):
with mock.patch("socket.gethostname", return_value="10.10.10.2"):
with pytest.deprecated_call(
match=f"MyClusterEnvironment.{method_name}` has been deprecated in v1.6 and will be removed in v1.7"
):
MyClusterEnvironment()
def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():

View File

@ -19,60 +19,84 @@ import pytest
from pytorch_lightning.plugins.environments import LSFEnvironment
@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"})
def test_missing_lsb_hosts():
"""Test an error when the lsb hosts list cannot be found."""
del os.environ["LSB_HOSTS"]
with pytest.raises(ValueError, match="Could not find hosts in environment variable LSB_HOSTS"):
def _make_rankfile(tmp_path):
hosts = "batch\n10.10.10.0\n10.10.10.1\n10.10.10.2\n10.10.10.3"
p = tmp_path / "lsb_djob_rankfile"
p.write_text(hosts)
return str(p)
@mock.patch.dict(os.environ, {"LSB_JOBID": "1234"})
def test_missing_lsb_djob_rankfile():
"""Test an error when the LSB_DJOB_RANKFILE cannot be found."""
with pytest.raises(ValueError, match="Did not find the environment variable `LSB_DJOB_RANKFILE`"):
LSFEnvironment()
@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1", "LSB_JOBID": "1234"})
def test_missing_lsb_job_id():
@mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": "", "LSB_JOBID": "1234"})
def test_empty_lsb_djob_rankfile():
"""Test an error when the LSB_DJOB_RANKFILE is not populated."""
with pytest.raises(ValueError, match="The environment variable `LSB_DJOB_RANKFILE` is empty"):
LSFEnvironment()
def test_missing_lsb_job_id(tmp_path):
"""Test an error when the job id cannot be found."""
del os.environ["LSB_JOBID"]
with pytest.raises(ValueError, match="Could not find job id in environment variable LSB_JOBID"):
with mock.patch.dict(os.environ, {"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path)}), pytest.raises(
ValueError, match="Could not find job id in environment variable LSB_JOBID"
):
LSFEnvironment()
@mock.patch.dict(os.environ, {"MASTER_PORT": "4321", "LSB_JOBID": "1234", "LSB_HOSTS": "batch 10.10.10.0 10.10.10.1"})
def test_manual_main_port_and_address():
def test_manual_main_port_and_address(tmp_path):
"""Test a user can set the port manually through the MASTER_PORT env variable."""
env = LSFEnvironment()
assert env.main_port == 4321
@mock.patch.dict(
os.environ,
{
"LSB_HOSTS": "batch 10.10.10.0 10.10.10.1 10.10.10.2 10.10.10.3",
environ = {
"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path),
"LSB_JOBID": "1234",
"JSM_NAMESPACE_SIZE": "4",
"JSM_NAMESPACE_RANK": "3",
"JSM_NAMESPACE_LOCAL_RANK": "1",
},
)
def test_attributes_from_environment_variables():
}
with mock.patch.dict(os.environ, environ), mock.patch("socket.gethostname", return_value="10.10.10.2"):
env = LSFEnvironment()
assert env.main_port == 10234
def test_attributes_from_environment_variables(tmp_path):
"""Test that the LSF environment takes the attributes from the environment variables."""
env = LSFEnvironment()
assert env.creates_processes_externally
assert env.main_address == "10.10.10.0"
assert env.main_port == 10234
assert env.world_size() == 4
assert env.global_rank() == 3
assert env.local_rank() == 1
env.set_global_rank(100)
assert env.global_rank() == 3
env.set_world_size(100)
assert env.world_size() == 4
assert LSFEnvironment.detect()
environ = {
"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path),
"LSB_JOBID": "1234",
"JSM_NAMESPACE_SIZE": "4",
"JSM_NAMESPACE_RANK": "3",
"JSM_NAMESPACE_LOCAL_RANK": "1",
}
with mock.patch.dict(os.environ, environ), mock.patch("socket.gethostname", return_value="10.10.10.2"):
env = LSFEnvironment()
assert env.creates_processes_externally
assert env.main_address == "10.10.10.0"
assert env.main_port == 10234
assert env.world_size() == 4
assert env.global_rank() == 3
assert env.local_rank() == 1
env.set_global_rank(100)
assert env.global_rank() == 3
env.set_world_size(100)
assert env.world_size() == 4
assert LSFEnvironment.detect()
@mock.patch("socket.gethostname", return_value="host2")
@mock.patch.dict(os.environ, {"LSB_HOSTS": "batch host0 host1 host2 host3", "LSB_JOBID": "1234"})
def test_node_rank(_):
env = LSFEnvironment()
assert env.node_rank() == 2
def test_node_rank(tmp_path):
environ = {
"LSB_DJOB_RANKFILE": _make_rankfile(tmp_path),
"LSB_JOBID": "1234",
"JSM_NAMESPACE_SIZE": "4",
"JSM_NAMESPACE_RANK": "3",
"JSM_NAMESPACE_LOCAL_RANK": "1",
}
with mock.patch.dict(os.environ, environ), mock.patch("socket.gethostname", return_value="10.10.10.2"):
env = LSFEnvironment()
assert env.node_rank() == 2
def test_detect():
@ -83,7 +107,7 @@ def test_detect():
with mock.patch.dict(
os.environ,
{
"LSB_HOSTS": "",
"LSB_DJOB_RANKFILE": "",
"LSB_JOBID": "",
"JSM_NAMESPACE_SIZE": "",
"JSM_NAMESPACE_LOCAL_RANK": "",