ref: clean up ddp before final fix (#3817)
* ref: clean up ddp before final fix
* ref: clean up ddp before final fix
* ref: clean up ddp before final fix
* ref: clean up ddp before final fix
* ref: clean up ddp before final fix
This commit is contained in:
parent 0838c6bfce
commit ed1450a293
@@ -154,7 +154,7 @@ class AcceleratorConnector:
             accelerator_backend = accelerators.DDPCPUSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

         elif self.trainer.distributed_backend == "ddp":
-            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')
+            accelerator_backend = accelerators.DDPBackend(self.trainer)

         elif self.trainer.use_dp:
             accelerator_backend = accelerators.DataParallelBackend(self.trainer)
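For orientation, a hedged usage sketch of how the "ddp" branch above is reached from user code, assuming the Trainer(distributed_backend=...) argument of this release line and a 2-GPU machine; the model class is a placeholder, not part of this change:

# Usage sketch only; MyLightningModule is a hypothetical LightningModule subclass.
from pytorch_lightning import Trainer

trainer = Trainer(
    gpus=2,
    distributed_backend="ddp",  # the string the connector compares against above
)
# trainer.fit(MyLightningModule()) would then construct DDPBackend(trainer),
# now without the old mode='ddp' argument.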
@@ -43,30 +43,21 @@ else:

 class DDPBackend(Accelerator):

-    def __init__(self, trainer, mode: str = 'ddp'):
+    def __init__(self, trainer):
         super().__init__(trainer)
         self.task_idx = None
         self._has_spawned_children = False
-        self.mode = mode
         self.dist = LightningDistributed()

     def setup(self, model):
-        if self.mode == 'ddp':
-            self.__ddp_script_mode_setup()
-        elif self.mode == 'slurm_ddp':
-            self.__slurm_setup()
-        elif self.mode == 'torchelastic_ddp':
-            self.__torchelastic_setup()
-
         # first track model
         self.trainer.model = model

-    def __slurm_setup(self):
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
+        # start the other scripts
+        self._call_children_scripts()

-    def __torchelastic_setup(self):
-        self.task_idx = int(os.environ['LOCAL_RANK'])

-    def __ddp_script_mode_setup(self):
+    def _call_children_scripts(self):
         assert self.trainer.global_rank == 0
         self._check_can_spawn_children()
         self._has_spawned_children = True
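The renamed _call_children_scripts is the script-mode launcher: rank 0 re-executes the training script once per additional device, passing rank-specific environment variables to each child. A rough, self-contained sketch of that pattern (the function name, env-var choices, and loop below are simplified assumptions, not the library code):

# Illustrative sketch of the "re-launch the script for each device" pattern.
import os
import subprocess
import sys


def call_children_scripts(num_processes: int) -> None:
    # Only the original (rank-0) process is allowed to spawn children.
    assert int(os.environ.get("LOCAL_RANK", "0")) == 0

    for local_rank in range(1, num_processes):
        env = os.environ.copy()
        env["LOCAL_RANK"] = str(local_rank)     # rank assigned to this child
        env["WORLD_SIZE"] = str(num_processes)  # total number of processes
        # Re-run the same command line in a child process for this rank.
        subprocess.Popen([sys.executable] + sys.argv, env=env)

# call_children_scripts(2) would launch one extra copy of the current script.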
@@ -137,12 +128,9 @@ class DDPBackend(Accelerator):

     def train(self):
         model = self.trainer.model
-        if self.mode == 'ddp':
-            results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
-            del os.environ['WORLD_SIZE']
-            return results
-        else:
-            self.ddp_train(process_idx=self.task_idx, model=model)
+        results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
+        del os.environ['WORLD_SIZE']
+        return results

     def _check_can_spawn_children(self):
         if self._has_spawned_children:
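With the slurm/torchelastic variants split out, train() no longer branches on self.mode: it always runs the master path, then drops the WORLD_SIZE variable it exported for the child scripts. A small illustrative sketch of that export-then-clean-up idiom, assuming nothing beyond os.environ (the function and return value are placeholders):

import os


def run_master(world_size: int) -> str:
    # Export the value the re-launched child scripts will read.
    os.environ["WORLD_SIZE"] = str(world_size)
    try:
        return "results"  # placeholder for ddp_train(..., is_master=True)
    finally:
        # The hunk above uses `del os.environ['WORLD_SIZE']`, which assumes the
        # key is present; pop(..., None) is the tolerant form of the same cleanup.
        os.environ.pop("WORLD_SIZE", None)


print(run_master(2))  # -> results, with WORLD_SIZE removed again afterwards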
@@ -288,5 +276,4 @@ class DDPBackend(Accelerator):
         # clean up memory
         torch.cuda.empty_cache()

-        if self.trainer.global_rank == 0:
-            return results
+        return results
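The last hunk changes who gets the results back: previously only global rank 0 returned them and every other process implicitly returned None; now every rank returns its results. A tiny self-contained sketch of the behavioural difference (the rank and results values are made up):

def old_return(global_rank, results):
    # Pre-change behaviour: non-zero ranks fell off the end and returned None.
    if global_rank == 0:
        return results


def new_return(global_rank, results):
    # Post-change behaviour: every rank hands its results to the caller.
    return results


print(old_return(1, {"loss": 0.1}))  # None
print(new_return(1, {"loss": 0.1}))  # {'loss': 0.1}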