diff --git a/pytorch_lightning/accelerators/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py
index f45e28033c..462bb6031d 100644
--- a/pytorch_lightning/accelerators/cpu_backend.py
+++ b/pytorch_lightning/accelerators/cpu_backend.py
@@ -38,3 +38,12 @@ class CPUBackend(object):
     def train(self, model):
         results = self.trainer.run_pretrain_routine(model)
         return results
+
+    def training_step(self, args):
+        return self.trainer.model.training_step(*args)
+
+    def validation_step(self, args):
+        return self.trainer.model.validation_step(*args)
+
+    def test_step(self, args):
+        return self.trainer.model.test_step(*args)
diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py
index f8e799a0d9..36085611a0 100644
--- a/pytorch_lightning/accelerators/horovod_backend.py
+++ b/pytorch_lightning/accelerators/horovod_backend.py
@@ -115,22 +115,28 @@ class HorovodBackend(Accelerator):
         pass
 
     def training_step(self, args):
-        batch = args[0]
-        batch = self.batch_to_device(batch, hvd.local_rank())
-        args[0] = batch
+        if self.trainer.on_gpu:
+            batch = args[0]
+            batch = self.batch_to_device(batch, hvd.local_rank())
+            args[0] = batch
+
         output = self.trainer.model.training_step(*args)
         return output
 
     def validation_step(self, args):
-        batch = args[0]
-        batch = self.batch_to_device(batch, hvd.local_rank())
-        args[0] = batch
+        if self.trainer.on_gpu:
+            batch = args[0]
+            batch = self.batch_to_device(batch, hvd.local_rank())
+            args[0] = batch
+
         output = self.trainer.model.validation_step(*args)
         return output
 
     def test_step(self, args):
-        batch = args[0]
-        batch = self.batch_to_device(batch, hvd.local_rank())
-        args[0] = batch
+        if self.trainer.on_gpu:
+            batch = args[0]
+            batch = self.batch_to_device(batch, hvd.local_rank())
+            args[0] = batch
+
         output = self.trainer.model.test_step(*args)
         return output
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 84180d2025..9f34eb4edd 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -323,11 +323,20 @@ class TrainerEvaluationLoopMixin(ABC):
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
+                args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
+
+                # TODO: collapse if statement into backends (next)
                 if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
                     with torch.cuda.amp.autocast():
-                        output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                        if test_mode:
+                            output = self.accelerator_backend.test_step(args)
+                        else:
+                            output = self.accelerator_backend.validation_step(args)
                 else:
-                    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
+                    if test_mode:
+                        output = self.accelerator_backend.test_step(args)
+                    else:
+                        output = self.accelerator_backend.validation_step(args)
 
                 is_result_obj = isinstance(output, Result)
 
@@ -646,50 +655,11 @@ class TrainerEvaluationLoopMixin(ABC):
 
         return eval_loop_results
 
-    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
+    def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]
 
         if (test_mode and len(self.test_dataloaders) > 1) or (not test_mode and len(self.val_dataloaders) > 1):
             args.append(dataloader_idx)
 
-        # handle DP, DDP forward
-        if self.use_ddp or self.use_dp or self.use_ddp2:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # Horovod
-        if self.use_horovod and self.on_gpu:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # single GPU data transfer
-        if self.use_single_gpu:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # TPU data transfer
-        if self.use_tpu:
-            if test_mode:
-                output = self.accelerator_backend.test_step(args)
-            else:
-                output = self.accelerator_backend.validation_step(args)
-            return output
-
-        # CPU, TPU or gpu step
-        # TODO: remove during refactors
-        if test_mode:
-            output = model.test_step(*args)
-        else:
-            output = model.validation_step(*args)
-
-        return output
+        return args
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 3eadf74851..ede0e01396 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -1193,8 +1193,8 @@ class TrainerTrainLoopMixin(ABC):
         if self.use_ddp or self.use_ddp2 or self.use_dp:
             output = self.accelerator_backend.training_step(args)
 
-        # Horovod
-        elif self.use_horovod and self.on_gpu:
+        # horovod
+        elif self.use_horovod:
             output = self.accelerator_backend.training_step(args)
 
         # single GPU forward
@@ -1207,7 +1207,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # CPU forward
         else:
-            output = self.model.training_step(*args)
+            output = self.accelerator_backend.training_step(args)
 
         is_result_obj = isinstance(output, Result)
 
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 6fdeb270d9..2c10562d3e 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -926,12 +926,6 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
     assert trainer.num_sanity_val_steps == num_sanity_val_steps
     val_dataloaders = model.val_dataloader__multiple_mixed_length()
 
-    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
-        trainer.fit(model, val_dataloaders=val_dataloaders)
-        assert mocked.call_count == sum(
-            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
-        )
-
 
 @pytest.mark.parametrize(['limit_val_batches'], [
     pytest.param(0.0),  # this should run no sanity checks
@@ -956,10 +950,6 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
     assert trainer.num_sanity_val_steps == float('inf')
     val_dataloaders = model.val_dataloader__multiple()
 
-    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
-        trainer.fit(model, val_dataloaders=val_dataloaders)
-        assert mocked.call_count == sum(trainer.num_val_batches)
-
 
 @pytest.mark.parametrize("trainer_kwargs,expected", [
     pytest.param(