parent
9942f3ebdf
commit
afa43837a4
|
@ -21,9 +21,17 @@ from typing import Optional
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as torch_distrib
|
||||
import torch.distributed as dist
|
||||
|
||||
from pytorch_lightning.utilities.distributed import find_free_network_port
|
||||
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
|
||||
from pytorch_lightning.accelerators.base_backend import Accelerator
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import AMPType
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_only
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
from pytorch_lightning.distributed.dist import LightningDistributed
|
||||
|
||||
|
||||
try:
|
||||
from hydra.utils import to_absolute_path, get_original_cwd
|
||||
|
@ -34,13 +42,14 @@ else:
|
|||
HYDRA_AVAILABLE = True
|
||||
|
||||
|
||||
class DDPBackend(DDPBase):
|
||||
class DDPBackend(Accelerator):
|
||||
|
||||
def __init__(self, trainer, mode: str = 'ddp'):
|
||||
super().__init__(trainer)
|
||||
self.task_idx = None
|
||||
self._has_spawned_children = False
|
||||
self.mode = mode
|
||||
self.dist = LightningDistributed()
|
||||
|
||||
def setup(self, model):
|
||||
if self.mode == 'ddp':
|
||||
|
@ -130,11 +139,11 @@ class DDPBackend(DDPBase):
|
|||
def train(self):
|
||||
model = self.trainer.model
|
||||
if self.mode == 'ddp':
|
||||
results = self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model, is_master=True)
|
||||
results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
|
||||
del os.environ['WORLD_SIZE']
|
||||
return results
|
||||
else:
|
||||
self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
|
||||
self.ddp_train(process_idx=self.task_idx, model=model)
|
||||
|
||||
def _check_can_spawn_children(self):
|
||||
if self._has_spawned_children:
|
||||
|
@ -168,3 +177,117 @@ class DDPBackend(DDPBase):
|
|||
def get_device_ids(self):
|
||||
device_ids = [self.trainer.root_gpu]
|
||||
return device_ids
|
||||
|
||||
def training_step(self, args):
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
with torch.cuda.amp.autocast():
|
||||
output = self.trainer.model(*args)
|
||||
else:
|
||||
output = self.trainer.model(*args)
|
||||
return output
|
||||
|
||||
def validation_step(self, args):
|
||||
output = self.training_step(args)
|
||||
return output
|
||||
|
||||
def test_step(self, args):
|
||||
output = self.training_step(args)
|
||||
return output
|
||||
|
||||
def barrier(self, name: str = None):
|
||||
if torch_distrib.is_initialized():
|
||||
torch_distrib.barrier()
|
||||
|
||||
def early_stopping_should_stop(self, pl_module):
|
||||
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
|
||||
dist.all_reduce(stop, op=dist.reduce_op.SUM)
|
||||
dist.barrier()
|
||||
should_stop = stop == self.trainer.world_size
|
||||
return should_stop
|
||||
|
||||
def broadcast(self, obj, src=0):
|
||||
return self.dist.broadcast(obj)
|
||||
|
||||
def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
|
||||
"""
|
||||
Entry point for ddp
|
||||
|
||||
Args:
|
||||
process_idx:
|
||||
mp_queue: multiprocessing queue
|
||||
model:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
seed = os.environ.get("PL_GLOBAL_SEED")
|
||||
if seed is not None:
|
||||
seed_everything(int(seed))
|
||||
|
||||
# offset the process id if requested
|
||||
process_idx = process_idx + proc_offset
|
||||
|
||||
# show progressbar only on progress_rank 0
|
||||
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
|
||||
self.trainer.progress_bar_callback.disable()
|
||||
|
||||
# determine which process we are and world size
|
||||
self.set_world_ranks(process_idx)
|
||||
|
||||
# set warning rank
|
||||
rank_zero_only.rank = self.trainer.global_rank
|
||||
|
||||
# set up server using proc 0's ip address
|
||||
# try to init for 20 times at max in case ports are taken
|
||||
# where to store ip_table
|
||||
model.trainer = self.trainer
|
||||
model.init_ddp_connection(
|
||||
self.trainer.global_rank,
|
||||
self.trainer.world_size,
|
||||
self.trainer.is_slurm_managing_tasks
|
||||
)
|
||||
|
||||
# call setup after the ddp process has connected
|
||||
self.trainer.call_setup_hook(model)
|
||||
|
||||
# on world_size=0 let everyone know training is starting
|
||||
if self.trainer.is_global_zero and not torch.distributed.is_initialized():
|
||||
log.info('-' * 100)
|
||||
log.info(f'distributed_backend={self.trainer.distributed_backend}')
|
||||
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
|
||||
log.info('-' * 100)
|
||||
|
||||
# call sync_bn before .cuda(), configure_apex and configure_ddp
|
||||
if self.trainer.sync_batchnorm:
|
||||
model = model.configure_sync_batchnorm(model)
|
||||
|
||||
# move the model to the correct device
|
||||
self.model_to_device(model, process_idx, is_master)
|
||||
|
||||
# CHOOSE OPTIMIZER
|
||||
# allow for lr schedulers as well
|
||||
self.setup_optimizers(model)
|
||||
|
||||
# set model properties before going into wrapper
|
||||
self.trainer.model_connector.copy_trainer_model_properties(model)
|
||||
|
||||
# 16-bit
|
||||
model = self.trainer.precision_connector.connect(model)
|
||||
|
||||
# device ids change depending on the DDP setup
|
||||
device_ids = self.get_device_ids()
|
||||
|
||||
# allow user to configure ddp
|
||||
model = model.configure_ddp(model, device_ids)
|
||||
|
||||
# set up training routine
|
||||
self.trainer.train_loop.setup_training(model)
|
||||
|
||||
# train or test
|
||||
results = self.train_or_test()
|
||||
|
||||
# clean up memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if self.trainer.global_rank == 0:
|
||||
return results
|
||||
|
|
Loading…
Reference in New Issue