From 18160b81b5c39806b09b3bbdee8b71c316471a43 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Aug 2020 11:13:49 -0400 Subject: [PATCH] refactored horovod backend (#3122) --- .../accelerators/horovod_backend.py | 21 +++++++++++++++++++ pytorch_lightning/trainer/evaluation_loop.py | 7 +++++-- pytorch_lightning/trainer/training_loop.py | 4 +--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py index a5d760af60..f8e799a0d9 100644 --- a/pytorch_lightning/accelerators/horovod_backend.py +++ b/pytorch_lightning/accelerators/horovod_backend.py @@ -113,3 +113,24 @@ class HorovodBackend(Accelerator): def teardown(self): pass + + def training_step(self, args): + batch = args[0] + batch = self.batch_to_device(batch, hvd.local_rank()) + args[0] = batch + output = self.trainer.model.training_step(*args) + return output + + def validation_step(self, args): + batch = args[0] + batch = self.batch_to_device(batch, hvd.local_rank()) + args[0] = batch + output = self.trainer.model.validation_step(*args) + return output + + def test_step(self, args): + batch = args[0] + batch = self.batch_to_device(batch, hvd.local_rank()) + args[0] = batch + output = self.trainer.model.test_step(*args) + return output diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index ecee841382..84180d2025 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -663,8 +663,11 @@ class TrainerEvaluationLoopMixin(ABC): # Horovod if self.use_horovod and self.on_gpu: - batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) - args[0] = batch + if test_mode: + output = self.accelerator_backend.test_step(args) + else: + output = self.accelerator_backend.validation_step(args) + return output # single GPU data transfer if self.use_single_gpu: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ecbfee0c45..3eadf74851 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -1195,9 +1195,7 @@ class TrainerTrainLoopMixin(ABC): # Horovod elif self.use_horovod and self.on_gpu: - batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) - args[0] = batch - output = self.model.training_step(*args) + output = self.accelerator_backend.training_step(args) # single GPU forward elif self.use_single_gpu: