Refactor 1: moved tpu xxx_step to backend (#3118)

* moved tpu training_step

* refactored eval step

William Falcon, 2020-08-24 07:02:06 -04:00 (committed by GitHub)
parent 45e7491dcc
commit 3c88b0dd83
5 changed files with 68 additions and 37 deletions

View File

@@ -0,0 +1,15 @@
+import torch
+from typing import Any
+from pytorch_lightning.utilities.apply_func import move_data_to_device
+
+
+class Accelerator(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def batch_to_device(self, batch: Any, device: torch.device):
+        model = self.trainer.get_model()
+        if model is not None:
+            return model.transfer_batch_to_device(batch, device)
+        return move_data_to_device(batch, device)
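For orientation, a minimal sketch of how this new hook behaves when no LightningModule is attached; the _StubTrainer name below is illustrative only (not part of this commit), and the import path assumes the module added in this diff:

import torch

from pytorch_lightning.accelerators.base_backend import Accelerator


class _StubTrainer:
    # illustrative stand-in: a real Trainer exposes get_model()
    def get_model(self):
        return None  # no LightningModule attached


# with no model attached, batch_to_device falls back to move_data_to_device
accelerator = Accelerator(trainer=_StubTrainer())
batch = {"x": torch.zeros(2, 3), "y": torch.ones(2)}
batch = accelerator.batch_to_device(batch, torch.device("cpu"))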

View File

@@ -21,6 +21,7 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.accelerators.base_backend import Accelerator

 try:
     import torch_xla
@@ -32,10 +33,10 @@ else:
     XLA_AVAILABLE = True


-class TPUBackend(object):
+class TPUBackend(Accelerator):

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)
         self.start_method = None
         self.mp_queue = None
@@ -117,6 +118,47 @@ class TPUBackend(object):
         # persist info in spawn
         trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

+    def training_step(self, batch, args):
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.training_step(*args)
+        return output
+
+    def validation_step(self, batch, args):
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.validation_step(*args)
+        return output
+
+    def test_step(self, batch, args):
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.test_step(*args)
+        return output
+
+    def to_device(self, batch):
+        """
+        Transfers the data to the TPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
+
+        Return:
+            the tensor on the TPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        if not XLA_AVAILABLE:
+            raise MisconfigurationException(
+                'Requested to transfer batch to TPU but XLA is not available.'
+                ' Are you sure this machine has TPUs?'
+            )
+        device = xm.xla_device(self.trainer.tpu_id)
+        return self.batch_to_device(batch, device)
+
     def __save_end_of_training_weights(self, model: LightningModule, trainer):
         # when training ends on these platforms dump weights to get out of the main process
         if trainer.on_colab_kaggle:
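As a rough, self-contained sketch of the delegation pattern these step methods introduce (the trainer loop hands the raw batch plus the step args to the backend, which moves the batch to its device, patches args[0], then calls the model hook); ToyModel and ToyBackend are illustrative names and not part of this PR:

import torch


class ToyModel(torch.nn.Module):
    # illustrative LightningModule stand-in with only a training_step hook
    def training_step(self, batch, batch_idx):
        return {"loss": batch.float().mean()}


class ToyBackend:
    # mirrors the shape of TPUBackend.training_step, but targets a plain device
    def __init__(self, model, device):
        self.model = model
        self.device = device

    def to_device(self, batch):
        return batch.to(self.device)

    def training_step(self, batch, args):
        batch = self.to_device(batch)  # move the data first
        args[0] = batch                # replace the stale batch in the arg list
        return self.model.training_step(*args)


backend = ToyBackend(ToyModel(), torch.device("cpu"))
batch, batch_idx = torch.randn(4, 3), 0
output = backend.training_step(batch, [batch, batch_idx])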

View File

@@ -134,28 +134,6 @@ class TrainerDPMixin(ABC):
             m.global_rank = self.global_rank
             m.local_rank = self.local_rank

-    def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
-        """
-        Transfers the data to the TPU.
-
-        Args:
-            batch: A tensor or collection of tensors.
-            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
-
-        Return:
-            the tensor on the TPU device.
-
-        See Also:
-            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
-        """
-        if not XLA_AVAILABLE:
-            raise MisconfigurationException(
-                'Requested to transfer batch to TPU but XLA is not available.'
-                ' Are you sure this machine has TPUs?'
-            )
-        device = xm.xla_device(tpu_id)
-        return self.__transfer_batch_to_device(batch, device)
-
     def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
         """
         Transfers the data to the GPU.

View File

@@ -191,6 +191,7 @@ class TrainerEvaluationLoopMixin(ABC):
     on_validation_end: Callable
     on_test_start: Callable
     on_test_end: Callable
+    accelerator_backend: ...

     @abstractmethod
     def copy_trainer_model_properties(self, *args):
@@ -204,10 +205,6 @@
     def is_overridden(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def transfer_batch_to_tpu(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def transfer_batch_to_gpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -677,10 +674,14 @@
         # TPU data transfer
         if self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
-            args[0] = batch
+            if test_mode:
+                output = self.accelerator_backend.test_step(batch, args)
+            else:
+                output = self.accelerator_backend.validation_step(batch, args)
+            return output

         # CPU, TPU or gpu step
+        # TODO: remove during refactors
         if test_mode:
             output = model.test_step(*args)
         else:
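Read as plain control flow, the TPU branch now short-circuits through the backend instead of falling through to the direct model call; a hedged sketch of the resulting dispatch, assuming an accelerator_backend object exposing validation_step/test_step like the TPUBackend above (the function name is illustrative):

def evaluation_forward_sketch(self, model, batch, args, test_mode=False):
    # TPU path: delegate the whole step to the backend and return early
    if self.use_tpu:
        if test_mode:
            return self.accelerator_backend.test_step(batch, args)
        return self.accelerator_backend.validation_step(batch, args)

    # CPU/GPU path: call the LightningModule hook directly (unchanged here)
    if test_mode:
        return model.test_step(*args)
    return model.validation_step(*args)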

View File

@@ -258,6 +258,7 @@ class TrainerTrainLoopMixin(ABC):
     state: TrainerState
     amp_backend: AMPType
     on_tpu: bool
+    accelerator_backend: ...

     # Callback system
     callbacks: List[Callback]
@@ -290,10 +291,6 @@
     def transfer_batch_to_gpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def transfer_batch_to_tpu(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def clip_gradients(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -1217,9 +1214,7 @@
         # TPU support
         elif self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
-            args[0] = batch
-            output = self.model.training_step(*args)
+            output = self.accelerator_backend.training_step(batch, args)

         # CPU forward
         else: