diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py new file mode 100644 index 0000000000..e7468d9ac7 --- /dev/null +++ b/pytorch_lightning/accelerators/base_backend.py @@ -0,0 +1,15 @@ +import torch +from typing import Any +from pytorch_lightning.utilities.apply_func import move_data_to_device + + +class Accelerator(object): + + def __init__(self, trainer): + self.trainer = trainer + + def batch_to_device(self, batch: Any, device: torch.device): + model = self.trainer.get_model() + if model is not None: + return model.transfer_batch_to_device(batch, device) + return move_data_to_device(batch, device) diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py index e879af3f5c..e16d16c8c5 100644 --- a/pytorch_lightning/accelerators/tpu_backend.py +++ b/pytorch_lightning/accelerators/tpu_backend.py @@ -21,6 +21,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.accelerators.base_backend import Accelerator try: import torch_xla @@ -32,10 +33,10 @@ else: XLA_AVAILABLE = True -class TPUBackend(object): +class TPUBackend(Accelerator): def __init__(self, trainer): - self.trainer = trainer + super().__init__(trainer) self.start_method = None self.mp_queue = None @@ -117,6 +118,47 @@ class TPUBackend(object): # persist info in spawn trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results) + def training_step(self, batch, args): + batch = self.to_device(batch) + args[0] = batch + output = self.trainer.model.training_step(*args) + return output + + def validation_step(self, batch, args): + batch = self.to_device(batch) + args[0] = batch + output = self.trainer.model.validation_step(*args) + return output + + def test_step(self, batch, args): + batch = self.to_device(batch) + args[0] = batch + output = self.trainer.model.test_step(*args) + return output + + def to_device(self, batch): + """ + Transfers the data to the TPU. + + Args: + batch: A tensor or collection of tensors. + tpu_id: The id of the TPU core. If omitted, the first available core is chosen. + + Return: + the tensor on the TPU device. + + See Also: + - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` + """ + if not XLA_AVAILABLE: + raise MisconfigurationException( + 'Requested to transfer batch to TPU but XLA is not available.' + ' Are you sure this machine has TPUs?' + ) + device = xm.xla_device(self.trainer.tpu_id) + + return self.batch_to_device(batch, device) + def __save_end_of_training_weights(self, model: LightningModule, trainer): # when training ends on these platforms dump weights to get out of the main process if trainer.on_colab_kaggle: diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 9b810352cf..8eff7dfa9d 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -134,28 +134,6 @@ class TrainerDPMixin(ABC): m.global_rank = self.global_rank m.local_rank = self.local_rank - def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None): - """ - Transfers the data to the TPU. - - Args: - batch: A tensor or collection of tensors. - tpu_id: The id of the TPU core. If omitted, the first available core is chosen. - - Return: - the tensor on the TPU device. - - See Also: - - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` - """ - if not XLA_AVAILABLE: - raise MisconfigurationException( - 'Requested to transfer batch to TPU but XLA is not available.' - ' Are you sure this machine has TPUs?' - ) - device = xm.xla_device(tpu_id) - return self.__transfer_batch_to_device(batch, device) - def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): """ Transfers the data to the GPU. diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index a05c8a68b1..a9e2d476cd 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -191,6 +191,7 @@ class TrainerEvaluationLoopMixin(ABC): on_validation_end: Callable on_test_start: Callable on_test_end: Callable + accelerator_backend: ... @abstractmethod def copy_trainer_model_properties(self, *args): @@ -204,10 +205,6 @@ class TrainerEvaluationLoopMixin(ABC): def is_overridden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod - def transfer_batch_to_tpu(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod def transfer_batch_to_gpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @@ -677,10 +674,14 @@ class TrainerEvaluationLoopMixin(ABC): # TPU data transfer if self.use_tpu: - batch = self.transfer_batch_to_tpu(batch, self.tpu_id) - args[0] = batch + if test_mode: + output = self.accelerator_backend.test_step(batch, args) + else: + output = self.accelerator_backend.validation_step(batch, args) + return output # CPU, TPU or gpu step + # TODO: remove during refactors if test_mode: output = model.test_step(*args) else: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index bb5ed82650..66bd46a034 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -258,6 +258,7 @@ class TrainerTrainLoopMixin(ABC): state: TrainerState amp_backend: AMPType on_tpu: bool + accelerator_backend: ... # Callback system callbacks: List[Callback] @@ -290,10 +291,6 @@ class TrainerTrainLoopMixin(ABC): def transfer_batch_to_gpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod - def transfer_batch_to_tpu(self, *args): - """Warning: this is just empty shell for code implemented in other class.""" - @abstractmethod def clip_gradients(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @@ -1217,9 +1214,7 @@ class TrainerTrainLoopMixin(ABC): # TPU support elif self.use_tpu: - batch = self.transfer_batch_to_tpu(batch, self.tpu_id) - args[0] = batch - output = self.model.training_step(*args) + output = self.accelerator_backend.training_step(batch, args) # CPU forward else: