diff --git a/benchmarks/test_parity.py b/benchmarks/test_parity.py index d5a5cbf340..cbef2f61cb 100644 --- a/benchmarks/test_parity.py +++ b/benchmarks/test_parity.py @@ -11,7 +11,7 @@ from tests.base.models import ParityModuleRNN, ParityModuleMNIST @pytest.mark.parametrize('cls_model,max_diff', [ (ParityModuleRNN, 0.05), - (ParityModuleMNIST, 0.55) + (ParityModuleMNIST, 0.57) ]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_pytorch_parity(tmpdir, cls_model, max_diff): diff --git a/docs/source/lightning_module.rst b/docs/source/lightning_module.rst index fe00120e01..a4e9dbe183 100644 --- a/docs/source/lightning_module.rst +++ b/docs/source/lightning_module.rst @@ -1048,12 +1048,6 @@ get_progress_bar_dict .. autofunction:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict :noindex: -init_ddp_connection -~~~~~~~~~~~~~~~~~~~ - -.. autofunction:: pytorch_lightning.core.lightning.LightningModule.init_ddp_connection - :noindex: - tbptt_split_batch ~~~~~~~~~~~~~~~~~ diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py index 910dc235af..755775bc7e 100644 --- a/pytorch_lightning/accelerators/base_backend.py +++ b/pytorch_lightning/accelerators/base_backend.py @@ -1,3 +1,4 @@ +import os import math from enum import Enum from typing import Any @@ -8,6 +9,8 @@ from pytorch_lightning.utilities import AMPType, rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict +import torch.distributed as torch_distrib +from pytorch_lightning import _logger as log try: from apex import amp @@ -185,6 +188,58 @@ class Accelerator(object): self.trainer.lr_schedulers = lr_schedulers self.trainer.optimizer_frequencies = optimizer_frequencies + def init_ddp_connection( + self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True + ) -> None: + if is_slurm_managing_tasks: + self.trainer.slurm_connector.connect_ddp(global_rank, world_size) + else: + self.connect_torchelastic(global_rank, world_size) + + def connect_torchelastic( + self, global_rank: int, world_size: int + ) -> None: + """ + Override to define your custom way of setting up a distributed environment. + + Lightning's implementation uses env:// init by default and sets the first node as root + for SLURM managed cluster. + + Args: + global_rank: The global process idx. + world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus). + """ + + if "MASTER_ADDR" not in os.environ: + rank_zero_warn( + "MASTER_ADDR environment variable is not defined. Set as localhost" + ) + os.environ["MASTER_ADDR"] = "127.0.0.1" + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + + if "MASTER_PORT" not in os.environ: + rank_zero_warn( + "MASTER_PORT environment variable is not defined. Set as 12910" + ) + os.environ["MASTER_PORT"] = "12910" + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) != world_size: + rank_zero_warn( + f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) " + f"is not equal to the computed world size ({world_size}). Ignored." + ) + + torch_backend = "nccl" if self.trainer.on_gpu else "gloo" + + if not torch.distributed.is_initialized(): + log.info( + f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" + ) + torch_distrib.init_process_group( + torch_backend, rank=global_rank, world_size=world_size + ) + # TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos... class BackendType(Enum): diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py index 90ed0ff54c..2d1c106654 100644 --- a/pytorch_lightning/accelerators/ddp2_backend.py +++ b/pytorch_lightning/accelerators/ddp2_backend.py @@ -137,7 +137,7 @@ class DDP2Backend(Accelerator): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - model.init_ddp_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py index f71b183d71..0adb1758cc 100644 --- a/pytorch_lightning/accelerators/ddp_backend.py +++ b/pytorch_lightning/accelerators/ddp_backend.py @@ -222,7 +222,7 @@ class DDPBackend(Accelerator): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - model.init_ddp_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py index bc494147b9..5762b06223 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py @@ -90,7 +90,7 @@ class DDPCPUSpawnBackend(Accelerator): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - model.init_ddp_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_slurm_backend.py b/pytorch_lightning/accelerators/ddp_slurm_backend.py index 5c39a74615..875c10c569 100644 --- a/pytorch_lightning/accelerators/ddp_slurm_backend.py +++ b/pytorch_lightning/accelerators/ddp_slurm_backend.py @@ -128,7 +128,7 @@ class DDPSLURMBackend(Accelerator): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - model.init_ddp_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py index 0e101cf727..6b6b242c8b 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py @@ -103,7 +103,7 @@ class DDPSpawnBackend(Accelerator): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - model.init_ddp_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/accelerators/ddp_torchelastic_backend.py b/pytorch_lightning/accelerators/ddp_torchelastic_backend.py index c606667be7..0f5b6f36d5 100644 --- a/pytorch_lightning/accelerators/ddp_torchelastic_backend.py +++ b/pytorch_lightning/accelerators/ddp_torchelastic_backend.py @@ -124,7 +124,7 @@ class DDPTorchElasticBackend(Accelerator): # try to init for 20 times at max in case ports are taken # where to store ip_table model.trainer = self.trainer - model.init_ddp_connection( + self.init_ddp_connection( self.trainer.global_rank, self.trainer.world_size, self.trainer.is_slurm_managing_tasks diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index dc77ebc8cc..c55e78aac5 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -22,7 +22,6 @@ from argparse import Namespace from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import torch -import torch.distributed as torch_distrib from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks @@ -956,87 +955,6 @@ class LightningModule( ) return model - def _init_slurm_connection(self) -> None: - """""" - """ - Sets up environment variables necessary for pytorch distributed communications - based on slurm environment. - """ - # use slurm job id for the port number - # guarantees unique ports across jobs from same grid search - try: - # use the last 4 numbers in the job id as the id - default_port = os.environ["SLURM_JOB_ID"] - default_port = default_port[-4:] - - # all ports should be in the 10k+ range - default_port = int(default_port) + 15000 - - except Exception: - default_port = 12910 - - # if user gave a port number, use that one instead - try: - default_port = os.environ["MASTER_PORT"] - except Exception: - os.environ["MASTER_PORT"] = str(default_port) - - # figure out the root node addr - try: - root_node = os.environ["SLURM_NODELIST"].split(" ")[0] - except Exception: - root_node = "127.0.0.1" - - root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node) - os.environ["MASTER_ADDR"] = root_node - - def init_ddp_connection( - self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True - ) -> None: - """ - Override to define your custom way of setting up a distributed environment. - - Lightning's implementation uses env:// init by default and sets the first node as root - for SLURM managed cluster. - - Args: - global_rank: The global process idx. - world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus). - is_slurm_managing_tasks: is cluster managed by SLURM. - """ - if is_slurm_managing_tasks: - self._init_slurm_connection() - - if "MASTER_ADDR" not in os.environ: - rank_zero_warn( - "MASTER_ADDR environment variable is not defined. Set as localhost" - ) - os.environ["MASTER_ADDR"] = "127.0.0.1" - log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") - - if "MASTER_PORT" not in os.environ: - rank_zero_warn( - "MASTER_PORT environment variable is not defined. Set as 12910" - ) - os.environ["MASTER_PORT"] = "12910" - log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") - - if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) != world_size: - rank_zero_warn( - f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) " - f"is not equal to the computed world size ({world_size}). Ignored." - ) - - torch_backend = "nccl" if self.trainer.on_gpu else "gloo" - - if not torch.distributed.is_initialized(): - log.info( - f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" - ) - torch_distrib.init_process_group( - torch_backend, rank=global_rank, world_size=world_size - ) - def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule": """ Add global batchnorm for a model spread across multiple GPUs and nodes. @@ -1089,10 +1007,8 @@ class LightningModule( return model, optimizers def configure_optimizers( - self, - ) -> Optional[ - Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]] - ]: + self, + ): r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index 1fb29d267f..14f7c95e39 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -4,6 +4,8 @@ import signal from subprocess import call from pytorch_lightning import _logger as log from pytorch_lightning.utilities.distributed import rank_zero_info +import torch.distributed as torch_distrib +import torch class SLURMConnector: @@ -101,3 +103,49 @@ class SLURMConnector: def term_handler(self, signum, frame): # save log.info("bypassing sigterm") + + def connect_ddp(self, global_rank: int, world_size: int) -> None: + """""" + """ + Sets up environment variables necessary for pytorch distributed communications + based on slurm environment. + """ + # use slurm job id for the port number + # guarantees unique ports across jobs from same grid search + try: + # use the last 4 numbers in the job id as the id + default_port = os.environ["SLURM_JOB_ID"] + default_port = default_port[-4:] + + # all ports should be in the 10k+ range + default_port = int(default_port) + 15000 + + except Exception: + default_port = 12910 + + # if user gave a port number, use that one instead + try: + default_port = os.environ["MASTER_PORT"] + except Exception: + os.environ["MASTER_PORT"] = str(default_port) + log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}") + + # figure out the root node addr + try: + root_node = os.environ["SLURM_NODELIST"].split(" ")[0] + except Exception: + root_node = "127.0.0.1" + + root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node) + os.environ["MASTER_ADDR"] = root_node + log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}") + + torch_backend = "nccl" if self.trainer.on_gpu else "gloo" + + if not torch.distributed.is_initialized(): + log.info( + f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" + ) + torch_distrib.init_process_group( + torch_backend, rank=global_rank, world_size=world_size + )