refactored horovod backend (#3122)
This commit is contained in:
parent
2d42ec008f
commit
18160b81b5
|
@ -113,3 +113,24 @@ class HorovodBackend(Accelerator):
|
|||
|
||||
def teardown(self):
|
||||
pass
|
||||
|
||||
def training_step(self, args):
    """Run the model's ``training_step`` with the batch moved to this rank's device.

    Args:
        args: positional arguments for the model's ``training_step``; the
            batch is ``args[0]``. Must support item assignment, because the
            moved batch is written back into ``args[0]`` in place.

    Returns:
        Whatever ``self.trainer.model.training_step(*args)`` returns.
    """
    # Device index is the Horovod local rank of this process; the transferred
    # batch replaces the original one in the caller-visible ``args``.
    args[0] = self.batch_to_device(args[0], hvd.local_rank())
    return self.trainer.model.training_step(*args)
|
||||
|
||||
def validation_step(self, args):
    """Run the model's ``validation_step`` with the batch moved to this rank's device.

    Args:
        args: positional arguments for the model's ``validation_step``; the
            batch is ``args[0]``. Must support item assignment, because the
            moved batch is written back into ``args[0]`` in place.

    Returns:
        Whatever ``self.trainer.model.validation_step(*args)`` returns.
    """
    # Device index is the Horovod local rank of this process; the transferred
    # batch replaces the original one in the caller-visible ``args``.
    args[0] = self.batch_to_device(args[0], hvd.local_rank())
    return self.trainer.model.validation_step(*args)
|
||||
|
||||
def test_step(self, args):
    """Run the model's ``test_step`` with the batch moved to this rank's device.

    Args:
        args: positional arguments for the model's ``test_step``; the batch
            is ``args[0]``. Must support item assignment, because the moved
            batch is written back into ``args[0]`` in place.

    Returns:
        Whatever ``self.trainer.model.test_step(*args)`` returns.
    """
    # Device index is the Horovod local rank of this process; the transferred
    # batch replaces the original one in the caller-visible ``args``.
    args[0] = self.batch_to_device(args[0], hvd.local_rank())
    return self.trainer.model.test_step(*args)
|
||||
|
|
|
@ -663,8 +663,11 @@ class TrainerEvaluationLoopMixin(ABC):
|
|||
|
||||
# Horovod
|
||||
if self.use_horovod and self.on_gpu:
|
||||
batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
|
||||
args[0] = batch
|
||||
if test_mode:
|
||||
output = self.accelerator_backend.test_step(args)
|
||||
else:
|
||||
output = self.accelerator_backend.validation_step(args)
|
||||
return output
|
||||
|
||||
# single GPU data transfer
|
||||
if self.use_single_gpu:
|
||||
|
|
|
@ -1195,9 +1195,7 @@ class TrainerTrainLoopMixin(ABC):
|
|||
|
||||
# Horovod
|
||||
elif self.use_horovod and self.on_gpu:
|
||||
batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
|
||||
args[0] = batch
|
||||
output = self.model.training_step(*args)
|
||||
output = self.accelerator_backend.training_step(args)
|
||||
|
||||
# single GPU forward
|
||||
elif self.use_single_gpu:
|
||||
|
|
Loading…
Reference in New Issue