diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py
index a5d760af60..f8e799a0d9 100644
--- a/pytorch_lightning/accelerators/horovod_backend.py
+++ b/pytorch_lightning/accelerators/horovod_backend.py
@@ -113,3 +113,54 @@ class HorovodBackend(Accelerator):
 
     def teardown(self):
         pass
+
+    def _batch_to_local_device(self, args):
+        """Move the batch in ``args[0]`` to the device at ``hvd.local_rank()``.
+
+        ``args`` is updated in place, mirroring what the trainer loops did
+        before this transfer moved into the backend.
+
+        Args:
+            args: mutable sequence of step arguments; ``args[0]`` is the batch.
+
+        Returns:
+            The same ``args`` object, with ``args[0]`` replaced.
+        """
+        args[0] = self.batch_to_device(args[0], hvd.local_rank())
+        return args
+
+    def training_step(self, args):
+        """Transfer the batch, then call the model's ``training_step``.
+
+        Args:
+            args: mutable sequence of step arguments; ``args[0]`` is the batch.
+
+        Returns:
+            The value returned by ``self.trainer.model.training_step``.
+        """
+        args = self._batch_to_local_device(args)
+        return self.trainer.model.training_step(*args)
+
+    def validation_step(self, args):
+        """Transfer the batch, then call the model's ``validation_step``.
+
+        Args:
+            args: mutable sequence of step arguments; ``args[0]`` is the batch.
+
+        Returns:
+            The value returned by ``self.trainer.model.validation_step``.
+        """
+        args = self._batch_to_local_device(args)
+        return self.trainer.model.validation_step(*args)
+
+    def test_step(self, args):
+        """Transfer the batch, then call the model's ``test_step``.
+
+        Args:
+            args: mutable sequence of step arguments; ``args[0]`` is the batch.
+
+        Returns:
+            The value returned by ``self.trainer.model.test_step``.
+        """
+        args = self._batch_to_local_device(args)
+        return self.trainer.model.test_step(*args)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index ecee841382..84180d2025 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -663,8 +663,11 @@ class TrainerEvaluationLoopMixin(ABC):
 
         # Horovod
         if self.use_horovod and self.on_gpu:
-            batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
-            args[0] = batch
+            if test_mode:
+                output = self.accelerator_backend.test_step(args)
+            else:
+                output = self.accelerator_backend.validation_step(args)
+            return output
 
         # single GPU data transfer
         if self.use_single_gpu:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ecbfee0c45..3eadf74851 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -1195,9 +1195,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # Horovod
         elif self.use_horovod and self.on_gpu:
-            batch = self.transfer_batch_to_gpu(batch, hvd.local_rank())
-            args[0] = batch
-            output = self.model.training_step(*args)
+            output = self.accelerator_backend.training_step(args)
 
         # single GPU forward
         elif self.use_single_gpu: