refactored ddp backend forward (#3119)

William Falcon 2020-08-24 07:33:14 -04:00 committed by GitHub
parent 3c88b0dd83
commit 527b9dca36
7 changed files with 62 additions and 9 deletions


@@ -156,3 +156,15 @@ class DDP2Backend(object):
 
         # clean up memory
         torch.cuda.empty_cache()
+
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
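Note: validation_step and test_step here deliberately reuse training_step (the same pattern repeats in the DDP, DDP-spawn, and DP hunks below): self.trainer.model is the distributed wrapper, and its forward() routes to the matching LightningModule hook based on the module's current state, so all three calls reduce to invoking the wrapped model. A minimal sketch of that routing idea, using a hypothetical wrapper class (not Lightning's actual wrapper):

    import torch.nn as nn

    class RoutingWrapper(nn.Module):
        # Hypothetical stand-in for the DDP wrapper: forward() dispatches to
        # the right LightningModule hook, so train/val/test calls look alike.
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            if self.module.training:
                return self.module.training_step(*args, **kwargs)
            if getattr(self.module, 'testing', False):  # assumed flag
                return self.module.test_step(*args, **kwargs)
            return self.module.validation_step(*args, **kwargs)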


@@ -235,6 +235,18 @@ class DDPBackend(object):
         if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
             return results
 
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(


@@ -162,3 +162,15 @@ class DDPSpawnBackend(object):
 
         # clean up memory
         torch.cuda.empty_cache()
+
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output


@@ -98,10 +98,21 @@ class DataParallelBackend(object):
 
         return results
 
     def teardown(self):
        # replace the original fwd function
        self.trainer.model.forward = self.model_autocast_original_forward
 
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
     def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
         """
         Reinitialize optimizer.step properties added by schedulers
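Note: the teardown() shown above restores a patched forward: during setup the DP backend saves the model's original forward as model_autocast_original_forward and swaps in an autocast-wrapped version, and teardown puts the original back. A hedged sketch of that patch/restore pattern (simplified; not the backend's exact code):

    import torch

    def patch_forward_with_autocast(model):
        # Save a handle to the original forward so teardown can restore it.
        original_forward = model.forward

        def autocast_forward(*args, **kwargs):
            # Run the original forward under mixed precision.
            with torch.cuda.amp.autocast():
                return original_forward(*args, **kwargs)

        model.forward = autocast_forward
        return original_forward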


@@ -118,19 +118,22 @@ class TPUBackend(Accelerator):
         # persist info in spawn
         trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
-    def training_step(self, batch, args):
+    def training_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.training_step(*args)
         return output
 
-    def validation_step(self, batch, args):
+    def validation_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.validation_step(*args)
         return output
 
-    def test_step(self, batch, args):
+    def test_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.test_step(*args)
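Note: this hunk drops the separate batch parameter and unifies the TPU backend's step signatures with the other backends: the batch now travels as args[0], and the step moves it to the XLA device and writes it back before delegating to the model hook. A minimal self-contained sketch of the convention (the class and attribute names here are illustrative):

    class SketchTPUBackend:
        # Hypothetical simplification of the unified convention: args[0] is
        # the batch; the step moves it to the device, then calls the hook.
        def __init__(self, model, device):
            self.model = model
            self.device = device

        def to_device(self, batch):
            return batch.to(self.device)

        def training_step(self, args):
            args[0] = self.to_device(args[0])
            return self.model.training_step(*args)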


@@ -655,7 +655,10 @@ class TrainerEvaluationLoopMixin(ABC):
 
         # handle DP, DDP forward
         if self.use_ddp or self.use_dp or self.use_ddp2:
-            output = model(*args)
+            if test_mode:
+                output = self.accelerator_backend.test_step(args)
+            else:
+                output = self.accelerator_backend.validation_step(args)
             return output
 
         # Horovod
@@ -675,9 +678,9 @@ class TrainerEvaluationLoopMixin(ABC):
         # TPU data transfer
         if self.use_tpu:
             if test_mode:
-                output = self.accelerator_backend.test_step(batch, args)
+                output = self.accelerator_backend.test_step(args)
             else:
-                output = self.accelerator_backend.validation_step(batch, args)
+                output = self.accelerator_backend.validation_step(args)
             return output
 
         # CPU, TPU or gpu step
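Note: with this change the evaluation loop no longer calls the model's forward directly for DP/DDP/DDP2; it asks the accelerator backend for the mode-appropriate step, the same call shape the TPU path already uses. The dispatch reduces to (an assumed simplification, not the trainer's literal code):

    def evaluation_forward(backend, args, test_mode):
        # Pick the backend hook by mode instead of calling model(*args).
        step = backend.test_step if test_mode else backend.validation_step
        return step(args)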


@@ -1191,7 +1191,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # distributed forward
         if self.use_ddp or self.use_ddp2 or self.use_dp:
-            output = self.model(*args)
+            output = self.accelerator_backend.training_step(args)
 
         # Horovod
         elif self.use_horovod and self.on_gpu:
@@ -1214,7 +1214,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # TPU support
         elif self.use_tpu:
-            output = self.accelerator_backend.training_step(batch, args)
+            output = self.accelerator_backend.training_step(args)
 
         # CPU forward
         else:
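Note: the training loop gets the matching change: every accelerated path (DP, DDP, DDP2, DDP spawn, TPU) now funnels the forward pass through accelerator_backend.training_step(args) with the same single-argument convention. A condensed sketch of the resulting control flow (simplified from the hunks above; the CPU fallback shown is an assumption):

    def training_forward(trainer, args):
        if trainer.use_ddp or trainer.use_ddp2 or trainer.use_dp:
            return trainer.accelerator_backend.training_step(args)
        if trainer.use_tpu:
            return trainer.accelerator_backend.training_step(args)
        # CPU fallback: call the LightningModule hook directly (assumed).
        return trainer.model.training_step(*args)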