[ddp] Support multi-node distributed execution under torchelastic (#1811)
The changes are quite local and limited in nature: we check for a few indicator environment variables (SLURM_LOCALID, NODE_RANK, GROUP_RANK), in that order. If more than one is set, a warning is logged. This patch also fixes a minor bug in the comparison against the `WORLD_SIZE` environment variable, which is a string and must be cast to `int` before being compared with the computed world size.
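To illustrate the `WORLD_SIZE` issue, here is a minimal sketch (not part of the diff; the values are made up): `os.environ` only ever yields strings, so comparing the raw value against an integer world size is always unequal and the warning fired spuriously.

```python
import os

os.environ['WORLD_SIZE'] = '4'   # environment variables are always strings
world_size = 4                   # computed world size (int)

print(os.environ['WORLD_SIZE'] != world_size)       # True: '4' != 4, spurious warning
print(int(os.environ['WORLD_SIZE']) != world_size)  # False: compare after casting to int
```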
parent b1d9656470
commit aefc5314bc
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
+- Added support for Pytorch elastic distributed launch environment ([#1811](https://github.com/PyTorchLightning/pytorch-lightning/pull/1811))
 - Added callback for logging learning rates ([#1498](https://github.com/PyTorchLightning/pytorch-lightning/pull/1498))
 - Added transfer learning example (for a binary classification task in computer vision) ([#1564](https://github.com/PyTorchLightning/pytorch-lightning/pull/1564))
@@ -944,11 +944,12 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
             os.environ['MASTER_PORT'] = '12910'
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")
 
-        if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != world_size:
-            log.warning("WORLD_SIZE environment variable is not equal to the computed "
-                        "world size. Ignored.")
+        if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
+            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
+                        f"is not equal to the computed world size ({world_size}). Ignored.")
 
         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
         log.info(f"initializing proc_rank {proc_rank} world {world_size}")
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)
 
     def configure_apex(
@@ -279,6 +279,25 @@ class TrainerDDPMixin(ABC):
         if self.is_slurm_managing_tasks:
             log.info('Multi-processing is handled by Slurm.')
 
+    def determine_ddp_node_rank(self):
+        if self.is_slurm_managing_tasks:
+            return int(os.environ['SLURM_NODEID'])
+
+        # torchelastic uses the envvar GROUP_RANK, whereas other systems(?) use NODE_RANK.
+        # otherwise use given node rank or default to node rank 0
+        env_vars = ['NODE_RANK', 'GROUP_RANK']
+        node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
+        node_ids = [(k, v) for k, v in node_ids if v is not None]
+        if len(node_ids) == 0:
+            log.warning("No environment variable for node rank defined. Set as 0.")
+            return 0
+        if len(node_ids) > 1:
+            log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
+                        f"Using the first one.")
+        k, rank = node_ids.pop()
+        log.info(f"Using environment variable {k} for node rank ({rank}).")
+        return int(rank)
+
     def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         if data_parallel_device_ids is None:
             return
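For context (not part of the diff): a minimal sketch of what the new `determine_ddp_node_rank` resolves to under a torchelastic-style environment in which only `GROUP_RANK` is set and SLURM is not managing tasks. The environment values are hypothetical.

```python
import os

# hypothetical torchelastic worker environment: the elastic agent exports
# GROUP_RANK (the rank of this node's worker group); NODE_RANK and the
# SLURM variables are absent
os.environ['GROUP_RANK'] = '1'

# simplified stand-in for the non-SLURM path of determine_ddp_node_rank():
# collect whichever of NODE_RANK / GROUP_RANK is defined, default to 0
found = [(k, os.environ[k]) for k in ('NODE_RANK', 'GROUP_RANK') if k in os.environ]
node_rank = int(found[0][1]) if found else 0
assert node_rank == 1
```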
@@ -305,15 +324,6 @@ class TrainerDDPMixin(ABC):
         :param cluster_obj:
         :return:
         """
-        # node rank using relative slurm id if under slurm management
-        # otherwise use given node rank or default to node rank 0
-        try:
-            node_id = os.environ['SLURM_NODEID'] if self.is_slurm_managing_tasks else os.environ['NODE_RANK']
-            self.node_rank = int(node_id)
-        except KeyError:
-            log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.")
-            self.node_rank = 0
-
         # show progressbar only on progress_rank 0
         if (self.node_rank != 0 or process_idx != 0) and self.progress_bar_callback is not None:
             self.progress_bar_callback.disable()
@@ -483,8 +483,8 @@ class Trainer(
         # init flags for SLURM+ddp to work
         self.proc_rank = 0
         self.world_size = 1
-        self.node_rank = 0
         self.configure_slurm_ddp(self.num_nodes)
+        self.node_rank = self.determine_ddp_node_rank()
 
         # nvidia setup
         self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
@@ -796,11 +796,14 @@ class Trainer(
         if self.use_ddp2:
             task = int(os.environ['SLURM_LOCALID'])
             self.ddp_train(task, model)
 
         elif self.use_ddp:
             if self.is_slurm_managing_tasks:
                 task = int(os.environ['SLURM_LOCALID'])
                 self.ddp_train(task, model)
+            # torchelastic
+            elif 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
+                task = int(os.environ['LOCAL_RANK'])
+                self.ddp_train(task, model)
             else:
                 self.__set_random_port()
                 # track for predict
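As a summary of the last hunk (a sketch, not the actual Trainer code): when SLURM is not managing tasks but both `WORLD_SIZE` and `GROUP_RANK` are present, the process is assumed to have been launched by torchelastic and its local index is read from `LOCAL_RANK` instead of being assigned by spawning workers on a random port.

```python
import os

def resolve_local_task(is_slurm_managing_tasks: bool):
    """Sketch of the branch order Trainer.fit() follows after this patch."""
    if is_slurm_managing_tasks:
        # SLURM launched one task per process and assigned it a local id
        return int(os.environ['SLURM_LOCALID'])
    if 'WORLD_SIZE' in os.environ and 'GROUP_RANK' in os.environ:
        # torchelastic launched this process and exported LOCAL_RANK for it
        return int(os.environ['LOCAL_RANK'])
    # otherwise Lightning spawns the processes itself on a random port
    return None
```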