From 82d1128966633e7c9987d54c1331ab34538e8f13 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 24 Aug 2020 20:26:39 -0400 Subject: [PATCH] eval step scaling factor (#3136) --- pytorch_lightning/accelerators/cpu_backend.py | 14 ++++++++++++-- pytorch_lightning/accelerators/gpu_backend.py | 18 ++++++++++++++++++ .../accelerators/horovod_backend.py | 13 +++++++++++-- pytorch_lightning/trainer/evaluation_loop.py | 14 +++----------- 4 files changed, 44 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/accelerators/cpu_backend.py b/pytorch_lightning/accelerators/cpu_backend.py index 80be22b528..88c7329e6c 100644 --- a/pytorch_lightning/accelerators/cpu_backend.py +++ b/pytorch_lightning/accelerators/cpu_backend.py @@ -50,7 +50,17 @@ class CPUBackend(Accelerator): return output def validation_step(self, args): - return self.trainer.model.validation_step(*args) + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.trainer.model.validation_step(*args) + else: + output = self.trainer.model.validation_step(*args) + return output def test_step(self, args): - return self.trainer.model.test_step(*args) + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.trainer.model.test_step(*args) + else: + output = self.trainer.model.test_step(*args) + return output diff --git a/pytorch_lightning/accelerators/gpu_backend.py b/pytorch_lightning/accelerators/gpu_backend.py index 983be14a29..13a100f9d2 100644 --- a/pytorch_lightning/accelerators/gpu_backend.py +++ b/pytorch_lightning/accelerators/gpu_backend.py @@ -69,6 +69,15 @@ class GPUBackend(Accelerator): return output def validation_step(self, args): + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.__validation_step(args) + else: + output = self.__validation_step(args) + + return output + + def __validation_step(self, args): batch = args[0] batch = self.to_device(batch) args[0] = batch @@ -76,6 +85,15 @@ class GPUBackend(Accelerator): return output def test_step(self, args): + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.__test_step(args) + else: + output = self.__test_step(args) + + return output + + def __test_step(self, args): batch = args[0] batch = self.to_device(batch) args[0] = batch diff --git a/pytorch_lightning/accelerators/horovod_backend.py b/pytorch_lightning/accelerators/horovod_backend.py index dc3e9db03a..f8cc5fc5df 100644 --- a/pytorch_lightning/accelerators/horovod_backend.py +++ b/pytorch_lightning/accelerators/horovod_backend.py @@ -134,7 +134,12 @@ class HorovodBackend(Accelerator): batch = self.batch_to_device(batch, hvd.local_rank()) args[0] = batch - output = self.trainer.model.validation_step(*args) + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.trainer.model.validation_step(*args) + else: + output = self.trainer.model.validation_step(*args) + return output def test_step(self, args): @@ -143,5 +148,9 @@ class HorovodBackend(Accelerator): batch = self.batch_to_device(batch, hvd.local_rank()) args[0] = batch - output = self.trainer.model.test_step(*args) + if self.trainer.amp_backend == AMPType.NATIVE: + with torch.cuda.amp.autocast(): + output = self.trainer.model.test_step(*args) + else: + output = self.trainer.model.test_step(*args) return output diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 8c49ee4ed7..8843e03917 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -303,18 +303,10 @@ class TrainerEvaluationLoopMixin(ABC): # ----------------- args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) - # TODO: collapse if statement into backends (next) - if self.amp_backend == AMPType.NATIVE and not self.use_tpu: - with torch.cuda.amp.autocast(): - if test_mode: - output = self.accelerator_backend.test_step(args) - else: - output = self.accelerator_backend.validation_step(args) + if test_mode: + output = self.accelerator_backend.test_step(args) else: - if test_mode: - output = self.accelerator_backend.test_step(args) - else: - output = self.accelerator_backend.validation_step(args) + output = self.accelerator_backend.validation_step(args) is_result_obj = isinstance(output, Result)