data transfer model hook (+ refactor) (#1756)

* refactor and added hook

* variant a

* variant b

* add test

* revert rename

* add changelog

* docs

* resolve merge duplication

* overridden typo

* fix test

* tpu id

* raise if TPU not available

* re-use apply_to_collection function for parsing collections

* comment

* make utility function available to user

* documentation

* move changelog entry to top

* fix tpu transfer call

* fix call

* remove hardcoded string

* improve test

* call model hook by default

* Apply suggestions from code review

* rename utility function

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Adrian Wälchli 2020-06-03 03:45:19 +02:00 committed by GitHub
parent ade3f36b7a
commit 8211256c46
9 changed files with 150 additions and 48 deletions


@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))
- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([#1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
### Changed
- Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))


@@ -60,8 +60,8 @@ else:
     'Trainer',
     'LightningModule',
     'Callback',
-    'data_loader'
-    'seed_everything'
+    'data_loader',
+    'seed_everything',
 ]
 # necessary for regular bolts imports. Skip exception since bolts is not always installed
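The two missing commas above are more than a style issue: adjacent string literals are implicitly concatenated, so the old `__all__` silently exported one merged, nonexistent name. A minimal illustration (the surrounding entries are abbreviated here):

    # with the commas missing, Python joins the two literals into one string
    __all__ = [
        'Callback',
        'data_loader'
        'seed_everything'
    ]
    assert __all__ == ['Callback', 'data_loaderseed_everything']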


@@ -3,6 +3,8 @@ from typing import Any
 import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
+from pytorch_lightning.utilities import move_data_to_device
+
 try:
     from apex import amp
@@ -153,3 +155,48 @@ class ModelHooks(torch.nn.Module):
                     scaled_loss.backward()
         else:
             loss.backward()
+
+    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+        """
+        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
+        wrapped in a custom data structure.
+
+        The data types listed below (and any arbitrary nesting of them) are supported out of the box:
+
+        - :class:`torch.Tensor`
+        - :class:`list`
+        - :class:`dict`
+        - :class:`tuple`
+        - ``torchtext.data.Batch`` (COMING SOON)
+
+        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
+
+        Example::
+
+            def transfer_batch_to_device(self, batch, device):
+                if isinstance(batch, CustomBatch):
+                    # move all tensors in your custom data structure to the device
+                    batch.samples = batch.samples.to(device)
+                    batch.targets = batch.targets.to(device)
+                else:
+                    batch = super().transfer_batch_to_device(batch, device)
+                return batch
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any other device than the one passed in as argument (unless you know what you are doing).
+            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
+            batch and determines the target devices.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+            - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection`
+        """
+        return move_data_to_device(batch, device)
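As a quick illustration of the default behaviour of the new hook: the base implementation just hands the batch to `move_data_to_device`, which walks any nesting of the supported collection types and calls `.to()` on every tensor it finds. A minimal sketch (the batch layout below is made up, not part of this diff):

    import torch
    from pytorch_lightning.utilities import move_data_to_device

    # hypothetical nested batch, shaped like something a DataLoader might yield
    batch = {'images': torch.rand(4, 3, 32, 32),
             'meta': [torch.arange(4), ('id', torch.zeros(4))]}

    # structure is preserved; every tensor inside is moved to the target device
    moved = move_data_to_device(batch, torch.device('cpu'))
    assert moved['images'].device == torch.device('cpu')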


@@ -18,6 +18,7 @@ from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
+from pytorch_lightning.utilities import move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -99,58 +100,50 @@ class TrainerDPMixin(ABC):
             m.tpu_local_core_rank = self.tpu_local_core_rank
             m.tpu_global_core_rank = self.tpu_global_core_rank
 
-    def transfer_batch_to_tpu(self, batch):
-        return self.__transfer_data_to_device(batch, device='tpu')
-
-    def transfer_batch_to_gpu(self, batch, gpu_id):
-        return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id)
-
-    def __transfer_data_to_device(self, batch, device, gpu_id=None):
-        if device == 'tpu' and XLA_AVAILABLE:
-            # base case: object can be directly moved using `to`
-            if callable(getattr(batch, 'to', None)):
-                xla_device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device()
-                return batch.to(xla_device)
-
-        if device == 'gpu':
-            # base case: object can be directly moved using `cuda` or `to`
-            if callable(getattr(batch, 'cuda', None)):
-                # non_blocking will be ignored if tensor is not pinned.
-                # so we can always set it to True
-                return batch.cuda(gpu_id, non_blocking=True)
-
-            if callable(getattr(batch, 'to', None)):
-                # non_blocking will be ignored if tensor is not pinned.
-                # so we can always set it to True
-                return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
-
-        # when list
-        if isinstance(batch, list):
-            for i, x in enumerate(batch):
-                batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
-            return batch
-
-        # when tuple
-        if isinstance(batch, tuple):
-            # when namedtuple
-            if hasattr(batch, '_fields'):
-                elem_type = type(batch)
-                return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch))
-            else:
-                batch = list(batch)
-                for i, x in enumerate(batch):
-                    batch[i] = self.__transfer_data_to_device(x, device, gpu_id)
-                return tuple(batch)
-
-        # when dict
-        if isinstance(batch, dict):
-            for k, v in batch.items():
-                batch[k] = self.__transfer_data_to_device(v, device, gpu_id)
-            return batch
-
-        # nothing matches, return the value as is without transform
-        return batch
+    def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None):
+        """
+        Transfers the data to the TPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            tpu_id: The id of the TPU core. If omitted, the first available core is chosen.
+
+        Return:
+            the tensor on the TPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        if not XLA_AVAILABLE:
+            raise MisconfigurationException(
+                'Requested to transfer batch to TPU but XLA is not available.'
+                ' Are you sure this machine has TPUs?'
+            )
+        device = xm.xla_device(tpu_id)
+        return self.__transfer_batch_to_device(batch, device)
+
+    def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None):
+        """
+        Transfers the data to the GPU.
+
+        Args:
+            batch: A tensor or collection of tensors.
+            gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen.
+
+        Return:
+            the tensor on the GPU device.
+
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        device = torch.device('cuda', gpu_id)
+        return self.__transfer_batch_to_device(batch, device)
+
+    def __transfer_batch_to_device(self, batch: Any, device: torch.device):
+        model = self.get_model()
+        if model is not None:
+            return model.transfer_batch_to_device(batch, device)
+        return move_data_to_device(batch, device)
 
     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)

@@ -434,7 +434,7 @@ class TrainerEvaluationLoopMixin(ABC):
         # TPU data transfer
         if self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch
 
         # CPU, TPU or gpu step


@@ -753,7 +753,7 @@ class TrainerTrainLoopMixin(ABC):
         # TPU support
         elif self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch
             output = self.model.training_step(*args)


@@ -1,3 +1,4 @@
 """General utilities"""
 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import move_data_to_device


@@ -1,6 +1,8 @@
 from collections import Mapping, Sequence
 from typing import Any, Callable, Union
 
+import torch
+
 
 def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
     """
@@ -34,3 +36,24 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
 
     # data is neither of dtype, nor a collection
     return data
+
+
+def move_data_to_device(batch: Any, device: torch.device):
+    """
+    Transfers a collection of tensors to the given device.
+
+    Args:
+        batch: A tensor or collection of tensors. See :func:`apply_to_collection`
+            for a list of supported collection types.
+        device: The device to which tensors should be moved
+
+    Return:
+        the same collection but with all contained tensors residing on the new device.
+
+    See Also:
+        - :meth:`torch.Tensor.to`
+        - :class:`torch.device`
+    """
+    def to(tensor):
+        return tensor.to(device, non_blocking=True)
+
+    return apply_to_collection(batch, dtype=torch.Tensor, function=to)
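A short sketch of how the two utilities compose (the collection below is illustrative): `apply_to_collection` does the structural recursion over dicts, (named)tuples and lists, while `move_data_to_device` merely supplies the per-tensor `.to(device)` call:

    import torch
    from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device

    # an arbitrary nesting of the supported collection types
    batch = ([torch.ones(2), torch.zeros(2)], {'y': torch.arange(3)})

    # same structure back, with the function applied to every tensor
    doubled = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t * 2)
    on_cpu = move_data_to_device(batch, torch.device('cpu'))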


@@ -1,3 +1,5 @@
+from unittest.mock import MagicMock
+
 import pytest
 import torch
@@ -68,3 +70,37 @@ def test_training_epoch_end_metrics_collection(tmpdir):
     # metrics are kept after each epoch
     for i in range(num_epochs):
         assert metrics[f'epoch_metric_{i}'] == i
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_transfer_batch_hook():
+
+    class CustomBatch:
+
+        def __init__(self, data):
+            self.samples = data[0]
+            self.targets = data[1]
+
+    class CurrentTestModel(EvalModelTemplate):
+
+        hook_called = False
+
+        def transfer_batch_to_device(self, data, device):
+            self.hook_called = True
+            if isinstance(data, CustomBatch):
+                data.samples = data.samples.to(device)
+                data.targets = data.targets.to(device)
+            else:
+                data = super().transfer_batch_to_device(data, device)
+            return data
+
+    model = CurrentTestModel()
+    batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long)))
+
+    trainer = Trainer()
+    # running .fit() would require us to implement custom data loaders, we mock the model reference instead
+    trainer.get_model = MagicMock(return_value=model)
+
+    batch_gpu = trainer.transfer_batch_to_gpu(batch, 0)
+    expected = torch.device('cuda', 0)
+    assert model.hook_called
+    assert batch_gpu.samples.device == batch_gpu.targets.device == expected
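If one wanted to exercise the hook through an actual `.fit()` run rather than by mocking `get_model`, the missing piece mentioned in the comment above is a `DataLoader` whose `collate_fn` emits the custom structure. A rough sketch (hypothetical, not part of this PR; `CustomBatch` refers to the class defined in the test):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def collate_to_custom_batch(samples):
        # stack the per-sample tensors into the custom structure from the test above
        xs = torch.stack([s[0] for s in samples])
        ys = torch.stack([s[1] for s in samples])
        return CustomBatch((xs, ys))

    dataset = TensorDataset(torch.zeros(20, 28), torch.ones(20, 1, dtype=torch.long))
    loader = DataLoader(dataset, batch_size=5, collate_fn=collate_to_custom_batch)
    # trainer.fit(model, train_dataloader=loader) would then route every batch
    # through model.transfer_batch_to_device before it reaches training_step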