Fix typing in `pl.plugins.environments` (#10943)
This commit is contained in:
parent
6bfc0bbc56
commit
46f718d2ba
|
@ -47,7 +47,6 @@ module = [
|
|||
"pytorch_lightning.callbacks.finetuning",
|
||||
"pytorch_lightning.callbacks.lr_monitor",
|
||||
"pytorch_lightning.callbacks.model_checkpoint",
|
||||
"pytorch_lightning.callbacks.prediction_writer",
|
||||
"pytorch_lightning.callbacks.progress.base",
|
||||
"pytorch_lightning.callbacks.progress.progress",
|
||||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
|
@ -70,10 +69,6 @@ module = [
|
|||
"pytorch_lightning.loggers.test_tube",
|
||||
"pytorch_lightning.loggers.wandb",
|
||||
"pytorch_lightning.loops.epoch.training_epoch_loop",
|
||||
"pytorch_lightning.plugins.environments.lightning_environment",
|
||||
"pytorch_lightning.plugins.environments.lsf_environment",
|
||||
"pytorch_lightning.plugins.environments.slurm_environment",
|
||||
"pytorch_lightning.plugins.environments.torchelastic_environment",
|
||||
"pytorch_lightning.plugins.training_type.ddp",
|
||||
"pytorch_lightning.plugins.training_type.ddp2",
|
||||
"pytorch_lightning.plugins.training_type.ddp_spawn",
|
||||
|
|
|
@ -34,9 +34,9 @@ class LightningEnvironment(ClusterEnvironment):
|
|||
training as it provides a convenient way to launch the training script.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._main_port = None
|
||||
self._main_port: int = -1
|
||||
self._global_rank: int = 0
|
||||
self._world_size: int = 1
|
||||
|
||||
|
@ -55,9 +55,9 @@ class LightningEnvironment(ClusterEnvironment):
|
|||
|
||||
@property
|
||||
def main_port(self) -> int:
|
||||
if self._main_port is None:
|
||||
self._main_port = os.environ.get("MASTER_PORT", find_free_network_port())
|
||||
return int(self._main_port)
|
||||
if self._main_port == -1:
|
||||
self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port()))
|
||||
return self._main_port
|
||||
|
||||
@staticmethod
|
||||
def detect() -> bool:
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import os
|
||||
import socket
|
||||
from typing import Dict, List
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.plugins.environments import ClusterEnvironment
|
||||
|
@ -41,7 +42,7 @@ class LSFEnvironment(ClusterEnvironment):
|
|||
The world size for the task. This environment variable is set by jsrun
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# TODO: remove in 1.7
|
||||
if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
|
||||
|
@ -74,7 +75,7 @@ class LSFEnvironment(ClusterEnvironment):
|
|||
required_env_vars = {"LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
|
||||
return required_env_vars.issubset(os.environ.keys())
|
||||
|
||||
def world_size(self):
|
||||
def world_size(self) -> int:
|
||||
"""The world size is read from the environment variable `JSM_NAMESPACE_SIZE`."""
|
||||
var = "JSM_NAMESPACE_SIZE"
|
||||
world_size = os.environ.get(var)
|
||||
|
@ -88,7 +89,7 @@ class LSFEnvironment(ClusterEnvironment):
|
|||
def set_world_size(self, size: int) -> None:
|
||||
log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
|
||||
|
||||
def global_rank(self):
|
||||
def global_rank(self) -> int:
|
||||
"""The world size is read from the environment variable `JSM_NAMESPACE_RANK`."""
|
||||
var = "JSM_NAMESPACE_RANK"
|
||||
global_rank = os.environ.get(var)
|
||||
|
@ -102,7 +103,7 @@ class LSFEnvironment(ClusterEnvironment):
|
|||
def set_global_rank(self, rank: int) -> None:
|
||||
log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
|
||||
|
||||
def local_rank(self):
|
||||
def local_rank(self) -> int:
|
||||
"""The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
|
||||
var = "JSM_NAMESPACE_LOCAL_RANK"
|
||||
local_rank = os.environ.get(var)
|
||||
|
@ -113,11 +114,11 @@ class LSFEnvironment(ClusterEnvironment):
|
|||
)
|
||||
return int(local_rank)
|
||||
|
||||
def node_rank(self):
|
||||
def node_rank(self) -> int:
|
||||
"""The node rank is determined by the position of the current hostname in the list of hosts stored in the
|
||||
environment variable `LSB_HOSTS`."""
|
||||
hosts = self._read_hosts()
|
||||
count = {}
|
||||
count: Dict[str, int] = {}
|
||||
for host in hosts:
|
||||
if "batch" in host or "login" in host:
|
||||
continue
|
||||
|
@ -126,7 +127,7 @@ class LSFEnvironment(ClusterEnvironment):
|
|||
return count[socket.gethostname()]
|
||||
|
||||
@staticmethod
|
||||
def _read_hosts():
|
||||
def _read_hosts() -> List[str]:
|
||||
hosts = os.environ.get("LSB_HOSTS")
|
||||
if not hosts:
|
||||
raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
|
||||
|
@ -148,15 +149,13 @@ class LSFEnvironment(ClusterEnvironment):
|
|||
Uses the LSF job ID so all ranks can compute the main port.
|
||||
"""
|
||||
# check for user-specified main port
|
||||
port = os.environ.get("MASTER_PORT")
|
||||
if not port:
|
||||
jobid = os.environ.get("LSB_JOBID")
|
||||
if not jobid:
|
||||
raise ValueError("Could not find job id in environment variable LSB_JOBID")
|
||||
port = int(jobid)
|
||||
if "MASTER_PORT" in os.environ:
|
||||
log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
|
||||
return int(os.environ["MASTER_PORT"])
|
||||
if "LSB_JOBID" in os.environ:
|
||||
port = int(os.environ["LSB_JOBID"])
|
||||
# all ports should be in the 10k+ range
|
||||
port = int(port) % 1000 + 10000
|
||||
port = port % 1000 + 10000
|
||||
log.debug(f"calculated LSF main port: {port}")
|
||||
else:
|
||||
log.debug(f"using externally specified main port: {port}")
|
||||
return int(port)
|
||||
return port
|
||||
raise ValueError("Could not find job id in environment variable LSB_JOBID")
|
||||
|
|
|
@ -58,10 +58,10 @@ class SLURMEnvironment(ClusterEnvironment):
|
|||
# SLURM JOB = PORT number
|
||||
# -----------------------
|
||||
# this way every process knows what port to use
|
||||
default_port = os.environ.get("SLURM_JOB_ID")
|
||||
if default_port:
|
||||
job_id = os.environ.get("SLURM_JOB_ID")
|
||||
if job_id is not None:
|
||||
# use the last 4 numbers in the job id as the id
|
||||
default_port = default_port[-4:]
|
||||
default_port = job_id[-4:]
|
||||
# all ports should be in the 10k+ range
|
||||
default_port = int(default_port) + 15000
|
||||
else:
|
||||
|
@ -72,13 +72,12 @@ class SLURMEnvironment(ClusterEnvironment):
|
|||
# -----------------------
|
||||
# in case the user passed it in
|
||||
if "MASTER_PORT" in os.environ:
|
||||
default_port = os.environ["MASTER_PORT"]
|
||||
default_port = int(os.environ["MASTER_PORT"])
|
||||
else:
|
||||
os.environ["MASTER_PORT"] = str(default_port)
|
||||
|
||||
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
|
||||
|
||||
return int(default_port)
|
||||
return default_port
|
||||
|
||||
@staticmethod
|
||||
def detect() -> bool:
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
|
||||
|
@ -45,8 +44,7 @@ class TorchElasticEnvironment(ClusterEnvironment):
|
|||
rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
|
||||
main_address = os.environ.get("MASTER_ADDR")
|
||||
return main_address
|
||||
return os.environ["MASTER_ADDR"]
|
||||
|
||||
@property
|
||||
def main_port(self) -> int:
|
||||
|
@ -55,8 +53,7 @@ class TorchElasticEnvironment(ClusterEnvironment):
|
|||
os.environ["MASTER_PORT"] = "12910"
|
||||
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
|
||||
|
||||
port = int(os.environ.get("MASTER_PORT"))
|
||||
return port
|
||||
return int(os.environ["MASTER_PORT"])
|
||||
|
||||
@staticmethod
|
||||
def detect() -> bool:
|
||||
|
@ -64,9 +61,8 @@ class TorchElasticEnvironment(ClusterEnvironment):
|
|||
required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
|
||||
return required_env_vars.issubset(os.environ.keys())
|
||||
|
||||
def world_size(self) -> Optional[int]:
|
||||
world_size = os.environ.get("WORLD_SIZE")
|
||||
return int(world_size) if world_size is not None else world_size
|
||||
def world_size(self) -> int:
|
||||
return int(os.environ["WORLD_SIZE"])
|
||||
|
||||
def set_world_size(self, size: int) -> None:
|
||||
log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
|
||||
|
|
|
@ -27,7 +27,9 @@ def test_default_attributes():
|
|||
assert env.creates_processes_externally
|
||||
assert env.main_address == "127.0.0.1"
|
||||
assert env.main_port == 12910
|
||||
assert env.world_size() is None
|
||||
with pytest.raises(KeyError):
|
||||
# world size is required to be passed as env variable
|
||||
env.world_size()
|
||||
with pytest.raises(KeyError):
|
||||
# local rank is required to be passed as env variable
|
||||
env.local_rank()
|
||||
|
|
Loading…
Reference in New Issue