diff --git a/pytorch_lightning/accelerators/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py
index 80be22b528..88c7329e6c 100644
--- a/pytorch_lightning/accelerators/cpu_backend.py
+++ b/pytorch_lightning/accelerators/cpu_backend.py
@@ -50,7 +50,17 @@ class CPUBackend(Accelerator):
         return output
 
     def validation_step(self, args):
-        return self.trainer.model.validation_step(*args)
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model.validation_step(*args)
+        else:
+            output = self.trainer.model.validation_step(*args)
+        return output
 
     def test_step(self, args):
-        return self.trainer.model.test_step(*args)
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model.test_step(*args)
+        else:
+            output = self.trainer.model.test_step(*args)
+        return output
diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py
index 983be14a29..13a100f9d2 100644
--- a/pytorch_lightning/accelerators/gpu_backend.py
+++ b/pytorch_lightning/accelerators/gpu_backend.py
@@ -69,6 +69,15 @@ class GPUBackend(Accelerator):
         return output
 
     def validation_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.__validation_step(args)
+        else:
+            output = self.__validation_step(args)
+
+        return output
+
+    def __validation_step(self, args):
         batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
@@ -76,6 +85,15 @@ class GPUBackend(Accelerator):
         return output
 
     def test_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.__test_step(args)
+        else:
+            output = self.__test_step(args)
+
+        return output
+
+    def __test_step(self, args):
         batch = args[0]
         batch = self.to_device(batch)
         args[0] = batch
diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py
index dc3e9db03a..f8cc5fc5df 100644
--- a/pytorch_lightning/accelerators/horovod_backend.py
+++ b/pytorch_lightning/accelerators/horovod_backend.py
@@ -134,7 +134,12 @@ class HorovodBackend(Accelerator):
         batch = self.batch_to_device(batch, hvd.local_rank())
         args[0] = batch
 
-        output = self.trainer.model.validation_step(*args)
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model.validation_step(*args)
+        else:
+            output = self.trainer.model.validation_step(*args)
+
         return output
 
     def test_step(self, args):
@@ -143,5 +148,9 @@
         batch = self.batch_to_device(batch, hvd.local_rank())
         args[0] = batch
 
-        output = self.trainer.model.test_step(*args)
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model.test_step(*args)
+        else:
+            output = self.trainer.model.test_step(*args)
         return output
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 8c49ee4ed7..8843e03917 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -303,18 +303,10 @@ class TrainerEvaluationLoopMixin(ABC):
                 # -----------------
                 args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
 
-                # TODO: collapse if statement into backends (next)
-                if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
-                    with torch.cuda.amp.autocast():
-                        if test_mode:
-                            output = self.accelerator_backend.test_step(args)
-                        else:
-                            output = self.accelerator_backend.validation_step(args)
+                if test_mode:
+                    output = self.accelerator_backend.test_step(args)
                 else:
-                    if test_mode:
-                        output = self.accelerator_backend.test_step(args)
-                    else:
-                        output = self.accelerator_backend.validation_step(args)
+                    output = self.accelerator_backend.validation_step(args)
 
                 is_result_obj = isinstance(output, Result)
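
This diff moves the native-AMP `autocast()` branch out of the evaluation loop and into each accelerator backend, which leaves the same `if amp_backend == AMPType.NATIVE` / `else` block duplicated across `validation_step` and `test_step` in every backend. A minimal sketch of how that duplication could be factored into one shared context helper; `eval_context` is a hypothetical name, not part of this PR, and the `AMPType` enum below is a stand-in for `pytorch_lightning.utilities.AMPType`:

```python
import contextlib
from enum import Enum

import torch


class AMPType(Enum):
    # Stand-in for pytorch_lightning.utilities.AMPType
    NATIVE = 'native'
    APEX = 'apex'


def eval_context(amp_backend):
    """Return an autocast context when native AMP is active, else a no-op context."""
    if amp_backend == AMPType.NATIVE:
        return torch.cuda.amp.autocast()
    return contextlib.nullcontext()


# Hypothetical usage inside a backend's validation_step / test_step:
#     with eval_context(self.trainer.amp_backend):
#         output = self.trainer.model.validation_step(*args)
```

With such a helper, each backend's step method collapses to a single `with` block instead of the repeated four-line branch. Note also that the GPU backend's new `__validation_step` / `__test_step` helpers use Python's double-underscore name mangling, keeping the device-transfer logic private to `GPUBackend`.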