refactored horovod backend (#3122)

This commit is contained in:
William Falcon 2020-08-24 11:13:49 -04:00 committed by GitHub
parent 2d42ec008f
commit 18160b81b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 27 additions and 5 deletions

View File

@ -113,3 +113,24 @@ class HorovodBackend(Accelerator):
def teardown(self):
pass
def training_step(self, args):
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
output = self.trainer.model.training_step(*args)
return output
def validation_step(self, args):
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
output = self.trainer.model.validation_step(*args)
return output
def test_step(self, args):
batch = args[0]
batch = self.batch_to_device(batch, hvd.local_rank())
args[0] = batch
output = self.trainer.model.test_step(*args)
return output

View File

@ -663,8 +663,11 @@ class TrainerEvaluationLoopMixin(ABC):
# Horovod
if self.use_horovod and self.on_gpu:
batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
args[0] = batch
if test_mode:
output = self.accelerator_backend.test_step(args)
else:
output = self.accelerator_backend.validation_step(args)
return output
# single GPU data transfer
if self.use_single_gpu:

View File

@ -1195,9 +1195,7 @@ class TrainerTrainLoopMixin(ABC):
# Horovod
elif self.use_horovod and self.on_gpu:
batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
args[0] = batch
output = self.model.training_step(*args)
output = self.accelerator_backend.training_step(args)
# single GPU forward
elif self.use_single_gpu: