ref: ddp backend refactor (3) (#3208)
* ddp backend refactor
* ddp backend refactor

parent a8daf914f8
commit 6bae404bed
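The pattern repeated across every hunk below: each accelerator backend's setup(model) now stores the model on the trainer, and train() drops its model argument and reads self.trainer.model instead, so the Trainer can drive every backend through the same setup → train → teardown sequence. A minimal sketch of that contract (simplified bodies, not the library code):

    class SketchBackend:
        def __init__(self, trainer):
            self.trainer = trainer

        def setup(self, model):
            # device placement, AMP and optimizer wiring elided;
            # the essential change: park the model on the trainer
            self.trainer.model = model

        def train(self):
            # no model argument anymore; read it back from the trainer
            model = self.trainer.model
            return self.trainer.run_pretrain_routine(model)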
@@ -36,8 +36,10 @@ class CPUBackend(Accelerator):
         self.trainer.optimizers = optimizers
         self.trainer.lr_schedulers = lr_schedulers
         self.trainer.optimizer_frequencies = optimizer_frequencies
+        self.trainer.model = model
 
-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         results = self.trainer.run_pretrain_routine(model)
         return results
 
@@ -46,6 +46,8 @@ class DDP2Backend(Accelerator):
     def setup(self, model):
         self._resolve_task_idx()
 
+        self.trainer.model = model
+
     def _resolve_task_idx(self):
         if self.trainer.is_slurm_managing_tasks:
             self.task_idx = int(os.environ['SLURM_LOCALID'])

@@ -57,7 +59,8 @@ class DDP2Backend(Accelerator):
             m = 'ddp2 only works in SLURM or via torchelastic with the WORLD_SIZE, LOCAL_RANK, GROUP_RANK flags'
             raise MisconfigurationException(m)
 
-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
 
     def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):

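The error message above names the environment torchelastic must provide (WORLD_SIZE, LOCAL_RANK, GROUP_RANK) before ddp2 will run outside SLURM. A hedged sketch of such a presence check (the actual validation lives outside this hunk):

    import os

    REQUIRED_TORCHELASTIC_VARS = ('WORLD_SIZE', 'LOCAL_RANK', 'GROUP_RANK')

    def torchelastic_env_present() -> bool:
        # mirrors the flags listed in the MisconfigurationException message
        return all(v in os.environ for v in REQUIRED_TORCHELASTIC_VARS)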
@@ -57,13 +57,16 @@ class DDPBackend(Accelerator):
         elif self.mode == 'torchelastic_ddp':
             self.__torchelastic_setup()
 
+        self.trainer.model = model
+
     def __slurm_setup(self):
         self.task_idx = int(os.environ['SLURM_LOCALID'])
 
     def __torchelastic_setup(self):
         self.task_idx = int(os.environ['LOCAL_RANK'])
 
-    def train(self, model):
+    def train(self):
+        model = self.trainer.model
         self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)
 
     def spawn_ddp_children(self, model):

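Both DDP entry points derive the local process index from whichever launcher is in charge: SLURM exports SLURM_LOCALID, torchelastic exports LOCAL_RANK. A small sketch of that resolution, folding the __slurm_setup / __torchelastic_setup pair into one helper (hypothetical function, for illustration only):

    import os

    def resolve_task_idx(is_slurm_managing_tasks: bool) -> int:
        # SLURM_LOCALID under SLURM, LOCAL_RANK under torchelastic
        env_var = 'SLURM_LOCALID' if is_slurm_managing_tasks else 'LOCAL_RANK'
        return int(os.environ[env_var])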
@@ -41,7 +41,11 @@ class DDPSpawnBackend(Accelerator):
         smp = mp.get_context('spawn')
         self.mp_queue = smp.SimpleQueue()
 
-    def train(self, model):
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
+
         # train in children process
         mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))
 
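The spawn backend still forwards the model to its worker processes through mp.spawn's args tuple, so only the public train() signature changes. A rough sketch of how torch.multiprocessing.spawn fans arguments out to workers (hypothetical worker function, not ddp_train):

    import torch.multiprocessing as mp

    def _worker(process_idx, mp_queue, model):
        # spawn calls fn(process_idx, *args) in each child process
        print(f'worker {process_idx} received a {type(model).__name__}')

    if __name__ == '__main__':
        smp = mp.get_context('spawn')
        queue = smp.SimpleQueue()   # same queue type the backend creates
        model = object()            # stand-in for a LightningModule
        mp.spawn(_worker, nprocs=2, args=(queue, model))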
@@ -46,9 +46,11 @@ class GPUBackend(Accelerator):
 
         if self.trainer.amp_backend == AMPType.APEX:
             model = self._setup_nvidia_apex(model)
-        return model
 
-    def train(self, model):
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
         results = self.trainer.run_pretrain_routine(model)
         return results
 
@@ -54,6 +54,8 @@ class TPUBackend(Accelerator):
         smp = mp.get_context(self.start_method)
         self.mp_queue = smp.SimpleQueue()
 
+        self.trainer.model = model
+
     def teardown(self, model):
         # restore main state with best weights
         best_path = self.mp_queue.get()

@@ -75,8 +77,8 @@ class TPUBackend(Accelerator):
         self.__load_weights_on_main_process()
         return results
 
-    def train(self, model: LightningModule):
-        self.trainer.model = model
+    def train(self):
+        model = self.trainer.model
 
         # train
         if self.trainer.tpu_id is not None:

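teardown(model) keeps its model argument on TPU because the weights trained in the spawned processes come back to the main process as a checkpoint path on mp_queue and are reloaded there ("restore main state with best weights", best_path = self.mp_queue.get()). A rough sketch of that hand-off, assuming a plain state_dict save/load rather than the backend's exact code:

    import torch

    def child_process(mp_queue, model, ckpt_path='best.ckpt'):
        # inside the spawned process: persist weights, report the path back
        torch.save(model.state_dict(), ckpt_path)
        mp_queue.put(ckpt_path)

    def main_process(mp_queue, model):
        # on the main process: pull the path and restore the best weights
        best_path = mp_queue.get()
        model.load_state_dict(torch.load(best_path))
        return model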
@@ -1040,26 +1040,26 @@ class Trainer(
         if self.use_ddp2:
             self.accelerator_backend = DDP2Backend(self)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()
 
         elif use_slurm_ddp:
             self.accelerator_backend = DDPBackend(self, mode='slurm_ddp')
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()
 
         elif use_torchelastic_ddp:
             self.accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()
 
         # regular ddp using .spawn
         elif use_ddp_spawn:
             self.accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()
 
         # ddp

@@ -1082,20 +1082,20 @@ class Trainer(
 
         elif self.use_single_gpu:
             self.accelerator_backend = GPUBackend(self)
-            model = self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()
 
         elif self.use_tpu:
             self.accelerator_backend = TPUBackend(self)
             self.accelerator_backend.setup(model)
-            self.accelerator_backend.train(model)
+            self.accelerator_backend.train()
             self.accelerator_backend.teardown(model)
 
         else:
             self.accelerator_backend = CPUBackend(self)
             self.accelerator_backend.setup(model)
-            results = self.accelerator_backend.train(model)
+            results = self.accelerator_backend.train()
             self.accelerator_backend.teardown()
 
         # hook
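With these changes every branch of the dispatch above performs the same three calls; only backend construction differs (the TPU branch still passes model to teardown). A condensed sketch of that common sequence as a hypothetical helper, not code from the Trainer:

    def run_with_backend(trainer, backend, model):
        # construct -> setup(model) -> train() -> teardown()
        trainer.accelerator_backend = backend
        backend.setup(model)        # backend stores the model on the trainer
        results = backend.train()   # backend reads trainer.model internally
        backend.teardown()
        return results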