diff --git a/CHANGELOG.md b/CHANGELOG.md
index f0d53d1a73..64e475739e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))

+- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([#1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
+
 ### Changed

 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py
index 3a6ffc1d7f..74876ca471 100644
--- a/pytorch_lightning/__init__.py
+++ b/pytorch_lightning/__init__.py
@@ -60,8 +60,8 @@ else:
     'Trainer',
     'LightningModule',
     'Callback',
-    'data_loader'
-    'seed_everything'
+    'data_loader',
+    'seed_everything',
 ]

 # necessary for regular bolts imports. Skip exception since bolts is not always installed
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 1a3f05be11..960c712438 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -3,6 +3,8 @@ from typing import Any
 import torch
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
+from pytorch_lightning.utilities import move_data_to_device
+

 try:
     from apex import amp
@@ -153,3 +155,48 @@ class ModelHooks(torch.nn.Module):
                 scaled_loss.backward()
         else:
             loss.backward()
+
+    def transfer_batch_to_device(self, batch: Any, device: torch.device) -> Any:
+        """
+        Override this hook if your :class:`~torch.utils.data.DataLoader` returns tensors
+        wrapped in a custom data structure.
+
+        The data types listed below (and any arbitrary nesting of them) are supported out of the box:
+
+        - :class:`torch.Tensor`
+        - :class:`list`
+        - :class:`dict`
+        - :class:`tuple`
+        - ``torchtext.data.Batch`` (COMING SOON)
+
+        For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
+
+        Example::
+
+            def transfer_batch_to_device(self, batch, device):
+                if isinstance(batch, CustomBatch):
+                    # move all tensors in your custom data structure to the device
+                    batch.samples = batch.samples.to(device)
+                    batch.targets = batch.targets.to(device)
+                else:
+                    batch = super().transfer_batch_to_device(batch, device)
+                return batch
+
+        Args:
+            batch: A batch of data that needs to be transferred to a new device.
+            device: The target device as defined in PyTorch.
+
+        Returns:
+            A reference to the data on the new device.
+
+        Note:
+            This hook should only transfer the data and not modify it, nor should it move the data to
+            any other device than the one passed in as argument (unless you know what you are doing).
+            The :class:`~pytorch_lightning.trainer.trainer.Trainer` already takes care of splitting the
+            batch and determines the target devices.
+ + See Also: + - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` + - :func:`~pytorch_lightning.utilities.apply_func.apply_to_collection` + """ + return move_data_to_device(batch, device) diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index c2103d40e6..9ea54a0e00 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -18,6 +18,7 @@ from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, LightningDataParallel, ) +from pytorch_lightning.utilities import move_data_to_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.distributed import rank_zero_only @@ -99,58 +100,50 @@ class TrainerDPMixin(ABC): m.tpu_local_core_rank = self.tpu_local_core_rank m.tpu_global_core_rank = self.tpu_global_core_rank - def transfer_batch_to_tpu(self, batch): - return self.__transfer_data_to_device(batch, device='tpu') + def transfer_batch_to_tpu(self, batch: Any, tpu_id: Optional[int] = None): + """ + Transfers the data to the TPU. - def transfer_batch_to_gpu(self, batch, gpu_id): - return self.__transfer_data_to_device(batch, device='gpu', gpu_id=gpu_id) + Args: + batch: A tensor or collection of tensors. + tpu_id: The id of the TPU core. If omitted, the first available core is chosen. - def __transfer_data_to_device(self, batch, device, gpu_id=None): - if device == 'tpu' and XLA_AVAILABLE: - # base case: object can be directly moved using `to` - if callable(getattr(batch, 'to', None)): - xla_device = xm.xla_device(self.tpu_id) if self.tpu_id is not None else xm.xla_device() - return batch.to(xla_device) + Return: + the tensor on the TPU device. - if device == 'gpu': - # base case: object can be directly moved using `cuda` or `to` - if callable(getattr(batch, 'cuda', None)): - # non_blocking will be ignored if tensor is not pinned. - # so we can always set it to True - return batch.cuda(gpu_id, non_blocking=True) + See Also: + - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device` + """ + if not XLA_AVAILABLE: + raise MisconfigurationException( + 'Requested to transfer batch to TPU but XLA is not available.' + ' Are you sure this machine has TPUs?' + ) + device = xm.xla_device(tpu_id) + return self.__transfer_batch_to_device(batch, device) - if callable(getattr(batch, 'to', None)): - # non_blocking will be ignored if tensor is not pinned. - # so we can always set it to True - return batch.to(torch.device('cuda', gpu_id), non_blocking=True) + def transfer_batch_to_gpu(self, batch: Any, gpu_id: Optional[int] = None): + """ + Transfers the data to the GPU. - # when list - if isinstance(batch, list): - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) - return batch + Args: + batch: A tensor or collection of tensors. + gpu_id: The id of the GPU device. If omitted, the first available GPU is chosen. - # when tuple - if isinstance(batch, tuple): - # when namedtuple - if hasattr(batch, '_fields'): - elem_type = type(batch) - return elem_type(*(self.__transfer_data_to_device(x, device, gpu_id) for x in batch)) - else: - batch = list(batch) - for i, x in enumerate(batch): - batch[i] = self.__transfer_data_to_device(x, device, gpu_id) - return tuple(batch) + Return: + the tensor on the GPU device. 
-        # when dict
-        if isinstance(batch, dict):
-            for k, v in batch.items():
-                batch[k] = self.__transfer_data_to_device(v, device, gpu_id)
+        See Also:
+            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
+        """
+        device = torch.device('cuda', gpu_id)
+        return self.__transfer_batch_to_device(batch, device)

-            return batch
-
-        # nothing matches, return the value as is without transform
-        return batch
+    def __transfer_batch_to_device(self, batch: Any, device: torch.device):
+        model = self.get_model()
+        if model is not None:
+            return model.transfer_batch_to_device(batch, device)
+        return move_data_to_device(batch, device)

     def single_gpu_train(self, model):
         model.cuda(self.root_gpu)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index a8c866f990..0676e28bcd 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -434,7 +434,7 @@ class TrainerEvaluationLoopMixin(ABC):

         # TPU data transfer
         if self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch

         # CPU, TPU or gpu step
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 4961e58093..b7286797e2 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -753,7 +753,7 @@ class TrainerTrainLoopMixin(ABC):

         # TPU support
         elif self.use_tpu:
-            batch = self.transfer_batch_to_tpu(batch)
+            batch = self.transfer_batch_to_tpu(batch, self.tpu_id)
             args[0] = batch
             output = self.model.training_step(*args)

diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index c8bc280523..51eb3b283d 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -1,3 +1,4 @@
 """General utilities"""

 from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import move_data_to_device
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 724715c3d8..bb32f79df9 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -1,6 +1,8 @@
 from collections import Mapping, Sequence
 from typing import Any, Callable, Union

+import torch
+

 def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
     """
@@ -34,3 +36,24 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable

     # data is neither of dtype, nor a collection
     return data
+
+
+def move_data_to_device(batch: Any, device: torch.device):
+    """
+    Transfers a collection of tensors to the given device.
+
+    Args:
+        batch: A tensor or collection of tensors. See :func:`apply_to_collection`
+            for a list of supported collection types.
+        device: The device to which tensors should be moved
+
+    Return:
+        the same collection but with all contained tensors residing on the new device.
+ + See Also: + - :meth:`torch.Tensor.to` + - :class:`torch.device` + """ + def to(tensor): + return tensor.to(device, non_blocking=True) + return apply_to_collection(batch, dtype=torch.Tensor, function=to) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 568a8eae43..47b73eb9e7 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import pytest import torch @@ -68,3 +70,37 @@ def test_training_epoch_end_metrics_collection(tmpdir): # metrics are kept after each epoch for i in range(num_epochs): assert metrics[f'epoch_metric_{i}'] == i + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_transfer_batch_hook(): + + class CustomBatch: + + def __init__(self, data): + self.samples = data[0] + self.targets = data[1] + + class CurrentTestModel(EvalModelTemplate): + + hook_called = False + + def transfer_batch_to_device(self, data, device): + self.hook_called = True + if isinstance(data, CustomBatch): + data.samples = data.samples.to(device) + data.targets = data.targets.to(device) + else: + data = super().transfer_batch_to_device(data, device) + return data + + model = CurrentTestModel() + batch = CustomBatch((torch.zeros(5, 28), torch.ones(5, 1, dtype=torch.long))) + + trainer = Trainer() + # running .fit() would require us to implement custom data loaders, we mock the model reference instead + trainer.get_model = MagicMock(return_value=model) + batch_gpu = trainer.transfer_batch_to_gpu(batch, 0) + expected = torch.device('cuda', 0) + assert model.hook_called + assert batch_gpu.samples.device == batch_gpu.targets.device == expected