From e82d9cdb66fea9976c923b1619030cbeea202947 Mon Sep 17 00:00:00 2001 From: Mateusz Pieniak <31375424+mateuszpieniak@users.noreply.github.com> Date: Sat, 27 Jun 2020 22:36:45 +0200 Subject: [PATCH] Support torchtext on a single GPU (#2379) * Handle torchtext.data.Batch on GPU * Update CHANGELOG.md * Apply code review requests * Correct the docs * Change requirements --- CHANGELOG.md | 2 ++ pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/utilities/apply_func.py | 16 ++++++++++-- requirements/base.txt | 1 + requirements/extra.txt | 2 +- tests/models/test_gpu.py | 30 +++++++++++++++++++++++ 6 files changed, 49 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f89afd634..617bbcb749 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added TorchText support for moving data to GPU ([#2379](https://github.com/PyTorchLightning/pytorch-lightning/pull/2379)) + ### Changed - Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289)) diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 0c5215f772..9c329def39 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -208,7 +208,7 @@ class ModelHooks(Module): - :class:`list` - :class:`dict` - :class:`tuple` - - ``torchtext.data.Batch`` (COMING SOON) + - :class:`torchtext.data.batch.Batch` For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...). diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 6e0b530e8f..b5ec664b2a 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -3,6 +3,8 @@ from collections import Mapping, Sequence from typing import Any, Callable, Union import torch +from torchtext.data import Batch +from copy import copy def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any: @@ -84,6 +86,16 @@ def move_data_to_device(batch: Any, device: torch.device): - :meth:`torch.Tensor.to` - :class:`torch.device` """ - def to(data): + + def batch_to(data): + if isinstance(data, Batch): + # Shallow copy because each Batch has a reference to Dataset which contains all examples + device_data = copy(data) + for field in data.fields: + # Batch contains output of Field.process(...) which is tensor hence .to(...) exists + device_field = getattr(data, field).to(device, non_blocking=True) + setattr(device_data, field, device_field) + return device_data + return data.to(device, non_blocking=True) - return apply_to_collection(batch, dtype=TransferableDataType, function=to) + return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to) diff --git a/requirements/base.txt b/requirements/base.txt index 0f173c0f0a..e045fab912 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -7,3 +7,4 @@ tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py # pyyaml>=3.13 PyYAML>=5.1 # OmegaConf requirement +torchtext>=0.3.1 \ No newline at end of file diff --git a/requirements/extra.txt b/requirements/extra.txt index 68990e71a3..a9d8b6bdf4 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -10,4 +10,4 @@ matplotlib>=3.1.1 horovod>=0.19.1 omegaconf>=2.0.0 # scipy>=0.13.3 -scikit-learn>=0.20.0 +scikit-learn>=0.20.0 \ No newline at end of file diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 382866b625..734478f26a 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -10,6 +10,7 @@ from pytorch_lightning.core import memory from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate +from torchtext.data import Batch, Dataset, Example, Field, LabelField PRETEND_N_OF_GPUS = 16 @@ -301,3 +302,32 @@ def test_single_gpu_batch_parse(): batch = trainer.transfer_batch_to_gpu(CustomBatchType()) assert batch.a.type() == 'torch.cuda.FloatTensor' + + # torchtext.data.Batch + samples = [ + {'text': 'PyTorch Lightning is awesome!', 'label': 0}, + {'text': 'Please make it work with torchtext', 'label': 1} + ] + + text_field = Field() + label_field = LabelField() + fields = { + 'text': ('text', text_field), + 'label': ('label', label_field) + } + + examples = [Example.fromdict(sample, fields) for sample in samples] + dataset = Dataset( + examples=examples, + fields=fields.values() + ) + + # Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first + text_field.build_vocab(dataset) + label_field.build_vocab(dataset) + + batch = Batch(data=examples, dataset=dataset) + batch = trainer.transfer_batch_to_gpu(batch, 0) + + assert batch.text.type() == 'torch.cuda.LongTensor' + assert batch.label.type() == 'torch.cuda.LongTensor'