diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index abc065cd39..f4d31213c7 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch @@ -66,53 +66,25 @@ class GPUAccelerator(Accelerator): results = self.train_or_test() return results - def training_step(self, args): + def _step(self, model_step: Callable, args): + args[0] = self.to_device(args[0]) + if self.trainer.amp_backend == AMPType.NATIVE: with torch.cuda.amp.autocast(): - output = self.__training_step(args) + output = model_step(*args) else: - output = self.__training_step(args) + output = model_step(*args) return output - def __training_step(self, args): - batch = args[0] - batch = self.to_device(batch) - args[0] = batch - output = self.trainer.model.training_step(*args) - return output + def training_step(self, args): + return self._step(self.trainer.model.training_step, args) def validation_step(self, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.__validation_step(args) - else: - output = self.__validation_step(args) - - return output - - def __validation_step(self, args): - batch = args[0] - batch = self.to_device(batch) - args[0] = batch - output = self.trainer.model.validation_step(*args) - return output + return self._step(self.trainer.model.validation_step, args) def test_step(self, args): - if self.trainer.amp_backend == AMPType.NATIVE: - with torch.cuda.amp.autocast(): - output = self.__test_step(args) - else: - output = self.__test_step(args) - - return output - - def __test_step(self, args): - batch = args[0] - batch = self.to_device(batch) - args[0] = batch - output = self.trainer.model.test_step(*args) - return output + return self._step(self.trainer.model.test_step, args) def to_device(self, batch): gpu_id = 0 diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index a7752e42a9..74fd201df8 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -14,7 +14,7 @@ import io import os import re -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.multiprocessing as mp @@ -145,26 +145,18 @@ class TPUAccelerator(Accelerator): # persist info in spawn self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) + def _step(self, model_step: Callable, args): + args[0] = self.to_device(args[0]) + return model_step(*args) + def training_step(self, args): - batch = args[0] - batch = self.to_device(batch) - args[0] = batch - output = self.trainer.model.training_step(*args) - return output + return self._step(self.trainer.model.training_step, args) def validation_step(self, args): - batch = args[0] - batch = self.to_device(batch) - args[0] = batch - output = self.trainer.model.validation_step(*args) - return output + return self._step(self.trainer.model.validation_step, args) def test_step(self, args): - batch = args[0] - batch = self.to_device(batch) - args[0] = batch - output = self.trainer.model.test_step(*args) - return output + return self._step(self.trainer.model.test_step, args) def process_dataloader(self, dataloader): device = xm.xla_device(self.trainer.tpu_id)