data transfer model hook (+ refactor) (#1756)
* refactor and added hook
* variant a
* variant b
* add test
* revert rename
* add changelog
* docs
* resolve merge duplication
* overridden typo
* fix test
* tpu id
* raise if TPU not available
* re-use apply_to_collection function for parsing collections
* comment
* make utility function available to user
* documentation
* move changelog entry to top
* fix tpu transfer call
* fix call
* remove hardcoded string
* improve test
* call model hook by default
* Apply suggestions from code review
* rename utility function

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent ade3f36b7a
commit 8211256c46
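In user code, the new `transfer_batch_to_device` hook is overridden on the LightningModule to teach Lightning how to move custom batch objects. A minimal sketch (`CustomBatch` and `LitModel` are illustrative names; `forward`, `training_step`, and the dataloaders are omitted for brevity):

```python
import pytorch_lightning as pl


class CustomBatch:
    def __init__(self, samples, targets):
        self.samples = samples
        self.targets = targets


class LitModel(pl.LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomBatch):
            # Move the tensors wrapped by the custom structure ourselves.
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch
        # Fall back to the default handling (tensors and plain collections).
        return super().transfer_batch_to_device(batch, device)
```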
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))

+- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([#1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756)).
+
 ### Changed

 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
@@ -60,8 +60,8 @@ else:
         'Trainer',
         'LightningModule',
         'Callback',
-        'data_loader'
-        'seed_everything'
+        'data_loader',
+        'seed_everything',
     ]

     # necessary for regular bolts imports. Skip exception since bolts is not always installed
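Incidentally, the two trailing commas added here are more than style: without them, Python's implicit string-literal concatenation merges the adjacent entries into a single name in `__all__`. A minimal illustration in plain Python:

```python
# Adjacent string literals are concatenated at parse time.
without_commas = [
    'data_loader'
    'seed_everything'
]
with_commas = [
    'data_loader',
    'seed_everything',
]

assert without_commas == ['data_loaderseed_everything']  # one merged name
assert with_commas == ['data_loader', 'seed_everything']  # two separate names
```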
@@ -3,6 +3,8 @@ from typing import Any
 import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
+from pytorch_lightning.utilities import move_data_to_device
+

 try:
     from apex import amp
@@ -153,3 +155,48 @@ class ModelHooks(torch.nn.Module):
                     scaled_loss.backward()
             else:
                 loss.backward()
+
+    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+        """
+        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
+        wrapped in a custom data structure.
+
+        The data types listed below (and any arbitrary nesting of them) are supported out of the box:
+
+        - :class:`torch.Tensor`
+        - :class:`list`
+        - :class:`dict`
+        - :class:`tuple`
+        - ``torchtext.data.Batch`` (COMING SOON)
+
+        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
+
+        Example::
+
+            def transfer_batch_to_device(self, batch, device):
+                if isinstance(batch, CustomBatch):
+                    # move all tensors in your custom data structure to the device
+                    batch.samples = batch.samples.to(device)
+                    batch.targets = batch.targets.to(device)
+                else:
+                    batch = super().transfer_batch_to_device(batch, device)
+                return batch
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any other device than the one passed in as argument (unless you know what you are doing).
+            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
+            batch and determines the target devices.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
+        """
+        return move_data_to_device(batch, device)
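For the default path, no override is needed: the base hook simply returns `move_data_to_device(batch, device)`, so nested lists, tuples, and dicts of tensors are moved with their structure intact. A small sketch of that behaviour (the device is chosen defensively so the snippet also runs on CPU-only machines):

```python
import torch
from pytorch_lightning.utilities import move_data_to_device

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

batch = {
    "images": [torch.zeros(2, 3), torch.ones(2, 3)],  # list nested in a dict
    "labels": (torch.tensor([0, 1]),),                # tuple of tensors
}

moved = move_data_to_device(batch, device)  # exactly what the default hook returns

# Same nesting, every tensor now on the target device.
assert moved["images"][0].device.type == device.type
assert moved["labels"][0].device.type == device.type
```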
@@ -18,6 +18,7 @@ from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
+from pytorch_lightning.utilities import move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only

@@ -99,58 +100,50 @@ class TrainerDPMixin(ABC):
             m.tpu_local_core_rank = self.tpu_local_core_rank
             m.tpu_global_core_rank = self.tpu_global_core_rank

-    def transfer_batch_to_tpu(self, batch):
-        return self.__transfer_data_to_device(batch, device='tpu')
-
-    def transfer_batch_to_gpu(self, batch, gpu_id):
-        return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id)
-
-    def __transfer_data_to_device(self, batch, device, gpu_id=None):
-        if device == 'tpu' and XLA_AVAILABLE:
-            # base case: object can be directly moved using `to`
-            if callable(getattr(batch, 'to', None)):
-                xla_device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
-                return batch.to(xla_device)
-
-        if device == 'gpu':
-            # base case: object can be directly moved using `cuda` or `to`
-            if callable(getattr(batch, 'cuda', None)):
-                # non_blocking will be ignored if tensor is not pinned.
-                # so we can always set it to True
-                return batch.cuda(gpu_id, non_blocking=True)
-
-            if callable(getattr(batch, 'to', None)):
-                # non_blocking will be ignored if tensor is not pinned.
-                # so we can always set it to True
-                return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
-
-        # when list
-        if isinstance(batch, list):
-            for i, x in enumerate(batch):
-                batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
-            return batch
-
-        # when tuple
-        if isinstance(batch, tuple):
-            # when namedtuple
-            if hasattr(batch, '_fields'):
-                elem_type = type(batch)
-                return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
-            else:
-                batch = list(batch)
-                for i, x in enumerate(batch):
-                    batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
-                return tuple(batch)
-
-        # when dict
-        if isinstance(batch, dict):
-            for k, v in batch.items():
-                batch[k] = self.__transfer_data_to_device(v, device, gpu_id)
-
-            return batch
-
-        # nothing matches, return the value as is without transform
-        return batch
+    def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
+        """
+        Transfers the data to the TPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
+
+        Return:
+            the tensor on the TPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        if not XLA_AVAILABLE:
+            raise MisconfigurationException(
+                'Requested to transfer batch to TPU but XLA is not available.'
+                ' Are you sure this machine has TPUs?'
+            )
+        device = xm.xla_device(tpu_id)
+        return self.__transfer_batch_to_device(batch, device)
+
+    def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
+        """
+        Transfers the data to the GPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen.
+
+        Return:
+            the tensor on the GPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        device = torch.device('cuda', gpu_id)
+        return self.__transfer_batch_to_device(batch, device)
+
+    def __transfer_batch_to_device(self, batch: Any, device: torch.device):
+        model = self.get_model()
+        if model is not None:
+            return model.transfer_batch_to_device(batch, device)
+        return move_data_to_device(batch, device)

     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)
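One behavioural consequence of the refactor worth calling out: previously, requesting a TPU transfer on a machine without XLA fell through the type checks and silently returned the batch unchanged; now it raises. A hedged sketch of what that looks like on a machine without TPUs:

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

trainer = Trainer()
try:
    trainer.transfer_batch_to_tpu(torch.zeros(2), tpu_id=None)
except MisconfigurationException as err:
    # 'Requested to transfer batch to TPU but XLA is not available. ...'
    print(err)
```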
@@ -434,7 +434,7 @@ class TrainerEvaluationLoopMixin(ABC):

         # TPU data transfer
         if self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch

         # CPU, TPU or gpu step
@@ -753,7 +753,7 @@ class TrainerTrainLoopMixin(ABC):

             # TPU support
             elif self.use_tpu:
-                batch = self.transfer_batch_to_tpu(batch)
+                batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
                 args[0] = batch
                 output = self.model.training_step(*args)

@@ -1,3 +1,4 @@
 """General utilities"""

 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import move_data_to_device
@@ -1,6 +1,8 @@
 from collections import Mapping, Sequence
 from typing import Any, Callable, Union

+import torch
+

 def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
     """
@@ -34,3 +36,24 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable

     # data is neither of dtype, nor a collection
     return data
+
+
+def move_data_to_device(batch: Any, device: torch.device):
+    """
+    Transfers a collection of tensors to the given device.
+
+    Args:
+        batch: A tensor or collection of tensors. See :func:`apply_to_collection`
+            for a list of supported collection types.
+        device: The device to which tensors should be moved
+
+    Return:
+        the same collection but with all contained tensors residing on the new device.
+
+    See Also:
+        - :meth:`torch.Tensor.to`
+        - :class:`torch.device`
+    """
+    def to(tensor):
+        return tensor.to(device, non_blocking=True)
+    return apply_to_collection(batch, dtype=torch.Tensor, function=to)
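Since `move_data_to_device` is just `apply_to_collection` specialised to `Tensor.to(device)`, the same utility can apply any per-tensor function over a nested batch (this is the re-use the commit message refers to). A small sketch with an arbitrary cast:

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = {"x": [torch.zeros(2, 3)], "y": (torch.ones(4),)}

# Cast every tensor in the (arbitrarily nested) collection, keeping the structure.
double_batch = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t.double())

assert double_batch["x"][0].dtype == torch.float64
assert double_batch["y"][0].dtype == torch.float64
```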
@@ -1,3 +1,5 @@
+from unittest.mock import MagicMock
+
 import pytest
 import torch

@@ -68,3 +70,37 @@ def test_training_epoch_end_metrics_collection(tmpdir):
     # metrics are kept after each epoch
     for i in range(num_epochs):
         assert metrics[f'epoch_metric_{i}'] == i
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_transfer_batch_hook():
+
+    class CustomBatch:
+
+        def __init__(self, data):
+            self.samples = data[0]
+            self.targets = data[1]
+
+    class CurrentTestModel(EvalModelTemplate):
+
+        hook_called = False
+
+        def transfer_batch_to_device(self, data, device):
+            self.hook_called = True
+            if isinstance(data, CustomBatch):
+                data.samples = data.samples.to(device)
+                data.targets = data.targets.to(device)
+            else:
+                data = super().transfer_batch_to_device(data, device)
+            return data
+
+    model = CurrentTestModel()
+    batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))
+
+    trainer = Trainer()
+    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
+    trainer.get_model = MagicMock(return_value=model)
+    batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
+    expected = torch.device('cuda', 0)
+    assert model.hook_called
+    assert batch_gpu.samples.device == batch_gpu.targets.device == expected
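Outside the test, a batch like `CustomBatch` would typically be produced by a `DataLoader` `collate_fn`; the hook above is then the single place where Lightning learns how to move it. A hedged sketch of that wiring (names are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class CustomBatch:
    def __init__(self, data):
        self.samples = data[0]
        self.targets = data[1]


def collate_to_custom_batch(items):
    # items is a list of (sample, target) tuples coming from the dataset
    samples = torch.stack([x for x, _ in items])
    targets = torch.stack([y for _, y in items])
    return CustomBatch((samples, targets))


dataset = TensorDataset(torch.zeros(10, 28), torch.ones(10, 1, dtype=torch.long))
loader = DataLoader(dataset, batch_size=5, collate_fn=collate_to_custom_batch)

for batch in loader:
    assert isinstance(batch, CustomBatch)  # moved by transfer_batch_to_device during training
```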