Fix typing in `pl.plugins.environments` (#10943)

Adrian Wälchli 2021-12-07 03:14:02 +01:00 committed by GitHub
parent 6bfc0bbc56
commit 46f718d2ba
6 changed files with 33 additions and 42 deletions

pyproject.toml

@@ -47,7 +47,6 @@ module = [
     "pytorch_lightning.callbacks.finetuning",
     "pytorch_lightning.callbacks.lr_monitor",
     "pytorch_lightning.callbacks.model_checkpoint",
-    "pytorch_lightning.callbacks.prediction_writer",
     "pytorch_lightning.callbacks.progress.base",
     "pytorch_lightning.callbacks.progress.progress",
     "pytorch_lightning.callbacks.progress.rich_progress",
@@ -70,10 +69,6 @@ module = [
     "pytorch_lightning.loggers.test_tube",
     "pytorch_lightning.loggers.wandb",
     "pytorch_lightning.loops.epoch.training_epoch_loop",
-    "pytorch_lightning.plugins.environments.lightning_environment",
-    "pytorch_lightning.plugins.environments.lsf_environment",
-    "pytorch_lightning.plugins.environments.slurm_environment",
-    "pytorch_lightning.plugins.environments.torchelastic_environment",
     "pytorch_lightning.plugins.training_type.ddp",
     "pytorch_lightning.plugins.training_type.ddp2",
     "pytorch_lightning.plugins.training_type.ddp_spawn",

pytorch_lightning/plugins/environments/lightning_environment.py

@@ -34,9 +34,9 @@ class LightningEnvironment(ClusterEnvironment):
     training as it provides a convenient way to launch the training script.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
-        self._main_port = None
+        self._main_port: int = -1
         self._global_rank: int = 0
         self._world_size: int = 1

@@ -55,9 +55,9 @@ class LightningEnvironment(ClusterEnvironment):

     @property
     def main_port(self) -> int:
-        if self._main_port is None:
-            self._main_port = os.environ.get("MASTER_PORT", find_free_network_port())
-        return int(self._main_port)
+        if self._main_port == -1:
+            self._main_port = int(os.environ.get("MASTER_PORT", find_free_network_port()))
+        return self._main_port

     @staticmethod
     def detect() -> bool:
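
The net effect in LightningEnvironment: `_main_port` changes from an implicit `Optional[int]` to a plain `int` with `-1` as the "not resolved yet" sentinel, so `main_port` always returns `int` without a cast. A self-contained sketch of the pattern with illustrative names — only `MASTER_PORT` and the sentinel idea come from the diff:

    import os
    import socket


    def _find_free_port() -> int:
        # bind to port 0 so the OS assigns a free port
        # (stand-in for PL's find_free_network_port)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


    class _Env:
        def __init__(self) -> None:
            self._main_port: int = -1  # -1 means "not resolved yet"

        @property
        def main_port(self) -> int:
            if self._main_port == -1:
                # prefer a user-provided MASTER_PORT, else pick a free port
                self._main_port = int(os.environ.get("MASTER_PORT", _find_free_port()))
            return self._main_port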

pytorch_lightning/plugins/environments/lsf_environment.py

@@ -14,6 +14,7 @@
 import os
 import socket
+from typing import Dict, List

 from pytorch_lightning import _logger as log
 from pytorch_lightning.plugins.environments import ClusterEnvironment

@@ -41,7 +42,7 @@ class LSFEnvironment(ClusterEnvironment):
         The world size for the task. This environment variable is set by jsrun
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         # TODO: remove in 1.7
         if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
@@ -74,7 +75,7 @@ class LSFEnvironment(ClusterEnvironment):
         required_env_vars = {"LSB_JOBID", "LSB_HOSTS", "JSM_NAMESPACE_LOCAL_RANK", "JSM_NAMESPACE_SIZE"}
         return required_env_vars.issubset(os.environ.keys())

-    def world_size(self):
+    def world_size(self) -> int:
         """The world size is read from the environment variable `JSM_NAMESPACE_SIZE`."""
         var = "JSM_NAMESPACE_SIZE"
         world_size = os.environ.get(var)
@@ -88,7 +89,7 @@ class LSFEnvironment(ClusterEnvironment):
     def set_world_size(self, size: int) -> None:
         log.debug("LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

-    def global_rank(self):
+    def global_rank(self) -> int:
         """The global rank is read from the environment variable `JSM_NAMESPACE_RANK`."""
         var = "JSM_NAMESPACE_RANK"
         global_rank = os.environ.get(var)
@@ -102,7 +103,7 @@ class LSFEnvironment(ClusterEnvironment):
     def set_global_rank(self, rank: int) -> None:
         log.debug("LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

-    def local_rank(self):
+    def local_rank(self) -> int:
         """The local rank is read from the environment variable `JSM_NAMESPACE_LOCAL_RANK`."""
         var = "JSM_NAMESPACE_LOCAL_RANK"
         local_rank = os.environ.get(var)
@@ -113,11 +114,11 @@
         )
         return int(local_rank)

-    def node_rank(self):
+    def node_rank(self) -> int:
         """The node rank is determined by the position of the current hostname in the list of hosts stored in the
         environment variable `LSB_HOSTS`."""
         hosts = self._read_hosts()
-        count = {}
+        count: Dict[str, int] = {}
         for host in hosts:
             if "batch" in host or "login" in host:
                 continue
@@ -126,7 +127,7 @@ class LSFEnvironment(ClusterEnvironment):
         return count[socket.gethostname()]

     @staticmethod
-    def _read_hosts():
+    def _read_hosts() -> List[str]:
         hosts = os.environ.get("LSB_HOSTS")
         if not hosts:
             raise ValueError("Could not find hosts in environment variable LSB_HOSTS")
@@ -148,15 +149,13 @@
         Uses the LSF job ID so all ranks can compute the main port.
         """
         # check for user-specified main port
-        port = os.environ.get("MASTER_PORT")
-        if not port:
-            jobid = os.environ.get("LSB_JOBID")
-            if not jobid:
-                raise ValueError("Could not find job id in environment variable LSB_JOBID")
-            port = int(jobid)
+        if "MASTER_PORT" in os.environ:
+            log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
+            return int(os.environ["MASTER_PORT"])
+        if "LSB_JOBID" in os.environ:
+            port = int(os.environ["LSB_JOBID"])
             # all ports should be in the 10k+ range
-            port = int(port) % 1000 + 10000
+            port = port % 1000 + 10000
             log.debug(f"calculated LSF main port: {port}")
-        else:
-            log.debug(f"using externally specified main port: {port}")
-        return int(port)
+            return port
+        raise ValueError("Could not find job id in environment variable LSB_JOBID")

pytorch_lightning/plugins/environments/slurm_environment.py

@@ -58,10 +58,10 @@ class SLURMEnvironment(ClusterEnvironment):
         # SLURM JOB = PORT number
         # -----------------------
         # this way every process knows what port to use
-        default_port = os.environ.get("SLURM_JOB_ID")
-        if default_port:
+        job_id = os.environ.get("SLURM_JOB_ID")
+        if job_id is not None:
             # use the last 4 numbers in the job id as the id
-            default_port = default_port[-4:]
+            default_port = job_id[-4:]
             # all ports should be in the 10k+ range
             default_port = int(default_port) + 15000
         else:
@@ -72,13 +72,12 @@
         # -----------------------
         # in case the user passed it in
         if "MASTER_PORT" in os.environ:
-            default_port = os.environ["MASTER_PORT"]
+            default_port = int(os.environ["MASTER_PORT"])
         else:
             os.environ["MASTER_PORT"] = str(default_port)

         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
-
-        return int(default_port)
+        return default_port

     @staticmethod
     def detect() -> bool:
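
SLURM's variant of the same idea, sketched end to end: the last four digits of `SLURM_JOB_ID` plus an offset of 15000 give a stable per-job port, and an explicit `MASTER_PORT` overrides it. The 12910 fallback sits outside the shown hunks and is assumed from the surrounding PL code:

    import os


    def slurm_main_port() -> int:
        job_id = os.environ.get("SLURM_JOB_ID")
        if job_id is not None:
            # e.g. job 1234567 -> port 4567 + 15000 = 19567
            default_port = int(job_id[-4:]) + 15000
        else:
            default_port = 12910  # assumed fallback
        if "MASTER_PORT" in os.environ:
            default_port = int(os.environ["MASTER_PORT"])
        else:
            os.environ["MASTER_PORT"] = str(default_port)
        return default_port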

pytorch_lightning/plugins/environments/torchelastic_environment.py

@@ -14,7 +14,6 @@
 import logging
 import os
-from typing import Optional

 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn

@@ -45,8 +44,7 @@ class TorchElasticEnvironment(ClusterEnvironment):
             rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ["MASTER_ADDR"] = "127.0.0.1"
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
-        main_address = os.environ.get("MASTER_ADDR")
-        return main_address
+        return os.environ["MASTER_ADDR"]

     @property
     def main_port(self) -> int:
@@ -55,8 +53,7 @@ class TorchElasticEnvironment(ClusterEnvironment):
             os.environ["MASTER_PORT"] = "12910"
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

-        port = int(os.environ.get("MASTER_PORT"))
-        return port
+        return int(os.environ["MASTER_PORT"])

     @staticmethod
     def detect() -> bool:
@@ -64,9 +61,8 @@ class TorchElasticEnvironment(ClusterEnvironment):
         required_env_vars = {"RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE"}
         return required_env_vars.issubset(os.environ.keys())

-    def world_size(self) -> Optional[int]:
-        world_size = os.environ.get("WORLD_SIZE")
-        return int(world_size) if world_size is not None else world_size
+    def world_size(self) -> int:
+        return int(os.environ["WORLD_SIZE"])

     def set_world_size(self, size: int) -> None:
         log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

tests/plugins/environments/test_torchelastic_environment.py

@@ -27,7 +27,9 @@ def test_default_attributes():
     assert env.creates_processes_externally
     assert env.main_address == "127.0.0.1"
     assert env.main_port == 12910
-    assert env.world_size() is None
+    with pytest.raises(KeyError):
+        # world size is required to be passed as env variable
+        env.world_size()
     with pytest.raises(KeyError):
         # local rank is required to be passed as env variable
         env.local_rank()