diff --git a/pytorch_lightning/accelerators/ddp2_backend.py b/pytorch_lightning/accelerators/ddp2_backend.py
index 51251846f7..7b7258a152 100644
--- a/pytorch_lightning/accelerators/ddp2_backend.py
+++ b/pytorch_lightning/accelerators/ddp2_backend.py
@@ -156,3 +156,15 @@ class DDP2Backend(object):
 
         # clean up memory
         torch.cuda.empty_cache()
+
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index 6866a66543..70f293a577 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -235,6 +235,18 @@ class DDPBackend(object):
         if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
             return results
 
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
             raise RuntimeError(
diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index ee49540579..30c8cf44f8 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -162,3 +162,15 @@ class DDPSpawnBackend(object):
 
         # clean up memory
         torch.cuda.empty_cache()
+
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
diff --git a/pytorch_lightning/accelerators/dp_backend.py b/pytorch_lightning/accelerators/dp_backend.py
index 4013b939b3..70c307cf14 100644
--- a/pytorch_lightning/accelerators/dp_backend.py
+++ b/pytorch_lightning/accelerators/dp_backend.py
@@ -98,10 +98,21 @@ class DataParallelBackend(object):
         return results
 
     def teardown(self):
-
         # replace the original fwd function
         self.trainer.model.forward = self.model_autocast_original_forward
 
+    def training_step(self, args):
+        output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
     def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
         """
         Reinitialize optimizer.step properties added by schedulers
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index e16d16c8c5..1522e51afe 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -118,19 +118,22 @@ class TPUBackend(Accelerator):
         # persist info in spawn
         trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
 
-    def training_step(self, batch, args):
+    def training_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.training_step(*args)
         return output
 
-    def validation_step(self, batch, args):
+    def validation_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.validation_step(*args)
         return output
 
-    def test_step(self, batch, args):
+    def test_step(self, args):
+        batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
         output = self.trainer.model.test_step(*args)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index a9e2d476cd..ef48f5a81b 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -655,7 +655,10 @@ class TrainerEvaluationLoopMixin(ABC):
 
         # handle DP, DDP forward
         if self.use_ddp or self.use_dp or self.use_ddp2:
-            output = model(*args)
+            if test_mode:
+                output = self.accelerator_backend.test_step(args)
+            else:
+                output = self.accelerator_backend.validation_step(args)
             return output
 
         # Horovod
@@ -675,9 +678,9 @@ class TrainerEvaluationLoopMixin(ABC):
         # TPU data transfer
         if self.use_tpu:
             if test_mode:
-                output = self.accelerator_backend.test_step(batch, args)
+                output = self.accelerator_backend.test_step(args)
             else:
-                output = self.accelerator_backend.validation_step(batch, args)
+                output = self.accelerator_backend.validation_step(args)
             return output
 
         # CPU, TPU or gpu step
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 66bd46a034..cec3ff6e61 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -1191,7 +1191,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # distributed forward
        if self.use_ddp or self.use_ddp2 or self.use_dp:
-            output = self.model(*args)
+            output = self.accelerator_backend.training_step(args)
 
         # Horovod
         elif self.use_horovod and self.on_gpu:
@@ -1214,7 +1214,7 @@ class TrainerTrainLoopMixin(ABC):
 
         # TPU support
         elif self.use_tpu:
-            output = self.accelerator_backend.training_step(batch, args)
+            output = self.accelerator_backend.training_step(args)
 
         # CPU forward
         else:
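
A minimal, self-contained sketch of the dispatch pattern this patch introduces, using toy stand-in classes (ToyBackend, evaluation_forward) rather than the real Trainer/accelerator API: the trainer loops hand the packed step args to the accelerator backend, which forwards them to the wrapped model, and validation/test reuse the same forward path.

# Illustrative sketch only; toy stand-ins, not the PyTorch Lightning API.

class ToyBackend:
    def __init__(self, model):
        self.model = model  # stands in for self.trainer.model (e.g. a DDP-wrapped module)

    def training_step(self, args):
        # unpack the packed (batch, batch_idx, ...) tuple into the model's forward
        return self.model(*args)

    def validation_step(self, args):
        # eval steps reuse the same forward dispatch
        return self.training_step(args)

    def test_step(self, args):
        return self.training_step(args)


def evaluation_forward(backend, args, test_mode):
    # mirrors the new test/validation branch added in evaluation_loop.py
    if test_mode:
        return backend.test_step(args)
    return backend.validation_step(args)


if __name__ == "__main__":
    backend = ToyBackend(lambda batch, batch_idx: {"batch_idx": batch_idx, "n": len(batch)})
    print(backend.training_step(([1, 2, 3], 0)))                          # training loop path
    print(evaluation_forward(backend, ([1, 2, 3], 1), test_mode=True))    # test path
    print(evaluation_forward(backend, ([1, 2, 3], 2), test_mode=False))   # validation path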