Refactor 1: moved tpu xxx_step to backend (#3118)
* moved tpu training_step
* refactored eval step
parent 45e7491dcc
commit 3c88b0dd83
@@ -0,0 +1,15 @@
+import torch
+from typing import Any
+from pytorch_lightning.utilities.apply_func import move_data_to_device
+
+
+class Accelerator(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def batch_to_device(self, batch: Any, device: torch.device):
+        model = self.trainer.get_model()
+        if model is not None:
+            return model.transfer_batch_to_device(batch, device)
+        return move_data_to_device(batch, device)
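For orientation, here is a self-contained, runnable stand-in for the pattern this new Accelerator base class enables (it is imported below as pytorch_lightning.accelerators.base_backend): a backend owns the trainer, moves the batch onto its device, and forwards the call to the model's step method. All names in this sketch (ToyAccelerator, ToyCpuBackend, ToyModel, ToyTrainer) are hypothetical illustrations, not library code; the real TPUBackend changes follow in the hunks below.

# Illustrative stand-in only: mirrors the structure of the new base class above,
# it is not pytorch_lightning code. Requires only torch.
import torch


class ToyAccelerator:
    # Stand-in for Accelerator: holds the trainer and moves batches to a device.
    def __init__(self, trainer):
        self.trainer = trainer

    def batch_to_device(self, batch, device):
        # The real base class defers to model.transfer_batch_to_device when a
        # model is attached; this sketch just moves a plain tensor.
        return batch.to(device)


class ToyCpuBackend(ToyAccelerator):
    # Stand-in for a concrete backend such as TPUBackend below:
    # move the batch, patch args[0], then call the model's step method.
    def training_step(self, batch, args):
        batch = self.batch_to_device(batch, torch.device("cpu"))
        args[0] = batch
        return self.trainer.model.training_step(*args)


class ToyModel(torch.nn.Linear):
    def training_step(self, batch, batch_idx):
        return self(batch).sum()


class ToyTrainer:
    def __init__(self, model):
        self.model = model


if __name__ == "__main__":
    backend = ToyCpuBackend(ToyTrainer(ToyModel(4, 2)))
    print(backend.training_step(torch.randn(8, 4), [None, 0]))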
@@ -21,6 +21,7 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.accelerators.base_backend import Accelerator

 try:
     import torch_xla
@@ -32,10 +33,10 @@ else:
 XLA_AVAILABLE = True


-class TPUBackend(object):
+class TPUBackend(Accelerator):

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)
         self.start_method = None
         self.mp_queue = None

@@ -117,6 +118,47 @@ class TPUBackend(object):
         # persist info in spawn
         trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

+    def training_step(self, batch, args):
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.training_step(*args)
+        return output
+
+    def validation_step(self, batch, args):
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.validation_step(*args)
+        return output
+
+    def test_step(self, batch, args):
+        batch = self.to_device(batch)
+        args[0] = batch
+        output = self.trainer.model.test_step(*args)
+        return output
+
+    def to_device(self, batch):
+        """
+        Transfers the data to the TPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
+
+        Return:
+            the tensor on the TPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        if not XLA_AVAILABLE:
+            raise MisconfigurationException(
+                'Requested to transfer batch to TPU but XLA is not available.'
+                ' Are you sure this machine has TPUs?'
+            )
+        device = xm.xla_device(self.trainer.tpu_id)
+
+        return self.batch_to_device(batch, device)
+
     def __save_end_of_training_weights(self, model: LightningModule, trainer):
         # when training ends on these platforms dump weights to get out of the main process
         if trainer.on_colab_kaggle:
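A note on the batch_to_device call at the end of to_device: because the base Accelerator first asks the attached model for transfer_batch_to_device (see the new base_backend file above), a LightningModule can keep custom batch containers working on TPU by overriding that hook. A hedged sketch, assuming a pytorch_lightning version from around this commit where LightningModule.transfer_batch_to_device(batch, device) exists; CustomBatch and MyModel are hypothetical names.

# Hedged sketch: assumes the LightningModule.transfer_batch_to_device(batch, device)
# hook of this era; CustomBatch and MyModel are made-up names for illustration.
from dataclasses import dataclass

import torch
from pytorch_lightning import LightningModule


@dataclass
class CustomBatch:
    inputs: torch.Tensor
    targets: torch.Tensor


class MyModel(LightningModule):
    def transfer_batch_to_device(self, batch, device):
        # Accelerator.batch_to_device (and therefore TPUBackend.to_device) routes
        # through this hook whenever a model is attached to the trainer.
        if isinstance(batch, CustomBatch):
            return CustomBatch(batch.inputs.to(device), batch.targets.to(device))
        return super().transfer_batch_to_device(batch, device)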
@@ -134,28 +134,6 @@ class TrainerDPMixin(ABC):
             m.global_rank = self.global_rank
             m.local_rank = self.local_rank

-    def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
-        """
-        Transfers the data to the TPU.
-
-        Args:
-            batch: A tensor or collection of tensors.
-            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
-
-        Return:
-            the tensor on the TPU device.
-
-        See Also:
-            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
-        """
-        if not XLA_AVAILABLE:
-            raise MisconfigurationException(
-                'Requested to transfer batch to TPU but XLA is not available.'
-                ' Are you sure this machine has TPUs?'
-            )
-        device = xm.xla_device(tpu_id)
-        return self.__transfer_batch_to_device(batch, device)
-
     def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
         """
         Transfers the data to the GPU.
@@ -191,6 +191,7 @@ class TrainerEvaluationLoopMixin(ABC):
     on_validation_end: Callable
     on_test_start: Callable
     on_test_end: Callable
+    accelerator_backend: ...

     @abstractmethod
     def copy_trainer_model_properties(self, *args):
@@ -204,10 +205,6 @@ class TrainerEvaluationLoopMixin(ABC):
     def is_overridden(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def transfer_batch_to_tpu(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def transfer_batch_to_gpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -677,10 +674,14 @@ class TrainerEvaluationLoopMixin(ABC):

         # TPU data transfer
         if self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
-            args[0] = batch
+            if test_mode:
+                output = self.accelerator_backend.test_step(batch, args)
+            else:
+                output = self.accelerator_backend.validation_step(batch, args)
+            return output

         # CPU, TPU or gpu step
+        # TODO: remove during refactors
         if test_mode:
             output = model.test_step(*args)
         else:
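To make the new hand-off explicit, here is a minimal self-contained restatement of the use_tpu branch above: the trainer still assembles args with the batch in slot 0, but device placement now happens inside the backend's step methods. StubBackend and evaluation_forward are hypothetical stand-ins, not trainer code.

# Pure-Python sketch of the dispatch introduced above; names are hypothetical.
class StubBackend:
    def validation_step(self, batch, args):
        return {"step": "validation", "batch_idx": args[1]}

    def test_step(self, batch, args):
        return {"step": "test", "batch_idx": args[1]}


def evaluation_forward(backend, batch, batch_idx, test_mode):
    args = [batch, batch_idx]          # batch stays in slot 0, as in the trainer
    if test_mode:
        return backend.test_step(batch, args)
    return backend.validation_step(batch, args)


print(evaluation_forward(StubBackend(), batch=[1, 2, 3], batch_idx=7, test_mode=False))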
@@ -258,6 +258,7 @@ class TrainerTrainLoopMixin(ABC):
     state: TrainerState
     amp_backend: AMPType
     on_tpu: bool
+    accelerator_backend: ...

     # Callback system
     callbacks: List[Callback]
@@ -290,10 +291,6 @@ class TrainerTrainLoopMixin(ABC):
     def transfer_batch_to_gpu(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def transfer_batch_to_tpu(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def clip_gradients(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -1217,9 +1214,7 @@ class TrainerTrainLoopMixin(ABC):

         # TPU support
         elif self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
-            args[0] = batch
-            output = self.model.training_step(*args)
+            output = self.accelerator_backend.training_step(batch, args)

         # CPU forward
         else: