refactored ddp backend forward (#3119)
parent 3c88b0dd83
commit 527b9dca36
@@ -156,3 +156,15 @@ class DDP2Backend(object):
 
         # clean up memory
         torch.cuda.empty_cache()
+
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
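Every backend touched by this commit gains the same three methods, and validation_step / test_step simply reuse training_step. That works because the call goes through the wrapped model's forward, which (via Lightning's LightningDistributedDataParallel / LightningDataParallel overrides) is expected to route to the right LightningModule hook. A rough, illustrative sketch of that dispatch, assuming a wrapper that checks the module's training flag and a trainer "testing" flag (the names below are assumptions for illustration, not the exact Lightning wrapper code):

import torch

class _DispatchingWrapper(torch.nn.Module):
    # illustrative only: routes forward() to the appropriate LightningModule hook
    def __init__(self, module, trainer):
        super().__init__()
        self.module = module        # the LightningModule
        self.trainer = trainer

    def forward(self, *args, **kwargs):
        if self.module.training:
            return self.module.training_step(*args, **kwargs)
        if getattr(self.trainer, 'testing', False):
            return self.module.test_step(*args, **kwargs)
        return self.module.validation_step(*args, **kwargs)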
@@ -235,6 +235,18 @@ class DDPBackend(object):
         if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
             return results
 
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(
@@ -162,3 +162,15 @@ class DDPSpawnBackend(object):
 
         # clean up memory
         torch.cuda.empty_cache()
+
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
@@ -98,10 +98,21 @@ class DataParallelBackend(object):
         return results
 
     def teardown(self):
 
         # replace the original fwd function
        self.trainer.model.forward = self.model_autocast_original_forward
 
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
     def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
         """
         Reinitialize optimizer.step properties added by schedulers
@@ -118,19 +118,22 @@ class TPUBackend(Accelerator):
         # persist info in spawn
         trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
-    def training_step(self, batch, args):
+    def training_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.training_step(*args)
         return output
 
-    def validation_step(self, batch, args):
+    def validation_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.validation_step(*args)
         return output
 
-    def test_step(self, batch, args):
+    def test_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.test_step(*args)
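The TPU backend adopts the same single-argument signature as the other backends by unpacking the batch itself: args[0] is taken to be the batch, it is moved to the device, and then written back into args before the LightningModule hook is called directly. Every accelerator can then be driven the same way at the call site; a minimal sketch of that convention (variable names are illustrative, not taken from the diff):

# assumed packing convention: args[0] is the batch, args[1] the batch index
args = [batch, batch_idx]
output = self.accelerator_backend.training_step(args)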
@@ -655,7 +655,10 @@ class TrainerEvaluationLoopMixin(ABC):
 
         # handle DP, DDP forward
         if self.use_ddp or self.use_dp or self.use_ddp2:
-            output = model(*args)
+            if test_mode:
+                output = self.accelerator_backend.test_step(args)
+            else:
+                output = self.accelerator_backend.validation_step(args)
             return output
 
         # Horovod
@@ -675,9 +678,9 @@ class TrainerEvaluationLoopMixin(ABC):
         # TPU data transfer
         if self.use_tpu:
             if test_mode:
-                output = self.accelerator_backend.test_step(batch, args)
+                output = self.accelerator_backend.test_step(args)
             else:
-                output = self.accelerator_backend.validation_step(batch, args)
+                output = self.accelerator_backend.validation_step(args)
             return output
 
         # CPU, TPU or gpu step
@@ -1191,7 +1191,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # distributed forward
         if self.use_ddp or self.use_ddp2 or self.use_dp:
-            output = self.model(*args)
+            output = self.accelerator_backend.training_step(args)
 
         # Horovod
         elif self.use_horovod and self.on_gpu:
@@ -1214,7 +1214,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # TPU support
         elif self.use_tpu:
-            output = self.accelerator_backend.training_step(batch, args)
+            output = self.accelerator_backend.training_step(args)
 
         # CPU forward
         else:
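Taken together, the training loop now routes the forward pass for DP/DDP/DDP2 and TPU through the same accelerator_backend.training_step(args) call, and the evaluation loop does the analogous thing with validation_step / test_step; only the Horovod and plain-CPU branches still call the model directly. A condensed sketch of the resulting dispatch shape (not the complete loop; the branches this commit leaves untouched are elided):

# rough shape of the distributed forward after this refactor
if self.use_ddp or self.use_ddp2 or self.use_dp:
    output = self.accelerator_backend.training_step(args)
elif self.use_horovod and self.on_gpu:
    ...  # Horovod path (unchanged by this commit)
elif self.use_tpu:
    output = self.accelerator_backend.training_step(args)
else:
    ...  # plain CPU forward (not shown in this diff)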