[ddp] Support multi-node distributed execution under torchelastic (#1811)

The changes are quite local and limited in nature: we check a few indicator
environment variables (SLURM_LOCALID, NODE_RANK, GROUP_RANK, in that order) to
detect the launch environment. If more than one of them is set, a warning is
logged.
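
For illustration only (not part of the diff), here is a minimal standalone sketch of that resolution order, mirroring the `determine_ddp_node_rank` helper this patch adds; the function name is ours, and the SLURM branch reads `SLURM_NODEID` just as the actual change does:

```python
import logging
import os

log = logging.getLogger(__name__)


def resolve_node_rank(is_slurm_managing_tasks: bool) -> int:
    # Under SLURM the scheduler tells us directly which node this is.
    if is_slurm_managing_tasks:
        return int(os.environ['SLURM_NODEID'])

    # torchelastic exports GROUP_RANK; other launchers export NODE_RANK.
    found = [(k, os.environ[k]) for k in ('NODE_RANK', 'GROUP_RANK') if k in os.environ]
    if not found:
        log.warning("No environment variable for node rank defined. Set as 0.")
        return 0
    if len(found) > 1:
        log.warning(f"Multiple environment variables {found} defined for node rank. "
                    f"Using the first one.")
    key, rank = found[0]
    log.info(f"Using environment variable {key} for node rank ({rank}).")
    return int(rank)
```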

This patch also fixes a minor bug in the comparison against the `WORLD_SIZE`
environment variable: environment variables are strings, so the value must be
cast to an int before comparing it with the computed world size.
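
A quick illustration of the bug (the values below are made up): `os.environ` always returns strings, so the old `!=` comparison against an int was always true and the warning fired spuriously.

```python
import os

os.environ['WORLD_SIZE'] = '4'   # launchers export this as a string
world_size = 4                   # the computed world size is an int

# Old check: str vs int never compares equal, so this is always True.
assert (os.environ['WORLD_SIZE'] != world_size) is True
# Fixed check: cast to int first, then compare.
assert (int(os.environ['WORLD_SIZE']) != world_size) is False
```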
Ashwin Bharambe 2020-05-13 11:06:59 -07:00 committed by GitHub
parent b1d9656470
commit aefc5314bc
4 changed files with 30 additions and 14 deletions

CHANGELOG.md

@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
+- Added support for Pytorch elastic distributed launch environment ([#1811](https://github.com/PyTorchLightning/pytorch-lightning/pull/1811))
 - Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))
 - Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564))

pytorch_lightning/core/lightning.py

@@ -944,11 +944,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
         os.environ['MASTER_PORT'] = '12910'
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
-        if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size:
-            log.warning("WORLD_SIZE environment variable is not equal to the computed "
-                        "world size. Ignored.")
+        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
+            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
+                        f"is not equal to the computed world size ({world_size}). Ignored.")
         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
         log.info(f"initializing proc_rank {proc_rank} world {world_size}")
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)

     def configure_apex(

pytorch_lightning/trainer/distrib_data_parallel.py

@@ -279,6 +279,25 @@ class TrainerDDPMixin(ABC):
         if self.is_slurm_managing_tasks:
             log.info('Multi-processing is handled by Slurm.')

+    def determine_ddp_node_rank(self):
+        if self.is_slurm_managing_tasks:
+            return int(os.environ['SLURM_NODEID'])
+
+        # torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK.
+        # otherwise use given node rank or default to node rank 0
+        env_vars = ['NODE_RANK', 'GROUP_RANK']
+        node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
+        node_ids = [(k, v) for k, v in node_ids if v is not None]
+        if len(node_ids) == 0:
+            log.warning("No environment variable for node rank defined. Set as 0.")
+            return 0
+        if len(node_ids) > 1:
+            log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
+                        f"Using the first one.")
+        k, rank = node_ids.pop()
+        log.info(f"Using environment variable {k} for node rank ({rank}).")
+        return int(rank)
+
     def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         if data_parallel_device_ids is None:
             return
@@ -305,15 +324,6 @@ class TrainerDDPMixin(ABC):
         :param cluster_obj:
         :return:
         """
-        # node rank using relative slurm id if under slurm management
-        # otherwise use given node rank or default to node rank 0
-        try:
-            node_id = os.environ['SLURM_NODEID'] if self.is_slurm_managing_tasks else os.environ['NODE_RANK']
-            self.node_rank = int(node_id)
-        except KeyError:
-            log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.")
-            self.node_rank = 0
-
         # show progressbar only on progress_rank 0
         if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
             self.progress_bar_callback.disable()

pytorch_lightning/trainer/trainer.py

@@ -483,8 +483,8 @@ class Trainer(
         # init flags for SLURM+ddp to work
         self.proc_rank = 0
         self.world_size = 1
-        self.node_rank = 0
         self.configure_slurm_ddp(self.num_nodes)
+        self.node_rank = self.determine_ddp_node_rank()

         # nvidia setup
         self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
@@ -796,11 +796,14 @@ class Trainer(
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
             self.ddp_train(task, model)

         elif self.use_ddp:
             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
                 self.ddp_train(task, model)

+            # torchelastic
+            elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
+                task = int(os.environ['LOCAL_RANK'])
+                self.ddp_train(task, model)

             else:
                 self.__set_random_port()

                 # track for predict