Support torchtext on a single GPU (#2379)
* Handle torchtext.data.Batch on GPU * Update CHANGELOG.md * Apply code review requests * Correct the docs * Change requirements
This commit is contained in:
parent
73a78a13c7
commit
e82d9cdb66
|
@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Added TorchText support for moving data to GPU ([#2379](https://github.com/PyTorchLightning/pytorch-lightning/pull/2379))
|
||||
|
||||
### Changed
|
||||
|
||||
- Changed epoch indexing from 0 instead of 1 ([#2289](https://github.com/PyTorchLightning/pytorch-lightning/pull/2289))
|
||||
|
|
|
@ -208,7 +208,7 @@ class ModelHooks(Module):
|
|||
- :class:`list`
|
||||
- :class:`dict`
|
||||
- :class:`tuple`
|
||||
- ``torchtext.data.Batch`` (COMING SOON)
|
||||
- :class:`torchtext.data.batch.Batch`
|
||||
|
||||
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, ...).
|
||||
|
||||
|
|
|
@ -3,6 +3,8 @@ from collections import Mapping, Sequence
|
|||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
from torchtext.data import Batch
|
||||
from copy import copy
|
||||
|
||||
|
||||
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
|
||||
|
@ -84,6 +86,16 @@ def move_data_to_device(batch: Any, device: torch.device):
|
|||
- :meth:`torch.Tensor.to`
|
||||
- :class:`torch.device`
|
||||
"""
|
||||
def to(data):
|
||||
|
||||
def batch_to(data):
|
||||
if isinstance(data, Batch):
|
||||
# Shallow copy because each Batch has a reference to Dataset which contains all examples
|
||||
device_data = copy(data)
|
||||
for field in data.fields:
|
||||
# Batch contains output of Field.process(...) which is tensor hence .to(...) exists
|
||||
device_field = getattr(data, field).to(device, non_blocking=True)
|
||||
setattr(device_data, field, device_field)
|
||||
return device_data
|
||||
|
||||
return data.to(device, non_blocking=True)
|
||||
return apply_to_collection(batch, dtype=TransferableDataType, function=to)
|
||||
return apply_to_collection(batch, dtype=(TransferableDataType, Batch), function=batch_to)
|
||||
|
|
|
@ -7,3 +7,4 @@ tensorboard>=1.14
|
|||
future>=0.17.1 # required for builtins in setup.py
|
||||
# pyyaml>=3.13
|
||||
PyYAML>=5.1 # OmegaConf requirement
|
||||
torchtext>=0.3.1
|
|
@ -10,4 +10,4 @@ matplotlib>=3.1.1
|
|||
horovod>=0.19.1
|
||||
omegaconf>=2.0.0
|
||||
# scipy>=0.13.3
|
||||
scikit-learn>=0.20.0
|
||||
scikit-learn>=0.20.0
|
|
@ -10,6 +10,7 @@ from pytorch_lightning.core import memory
|
|||
from pytorch_lightning.trainer.distrib_parts import _parse_gpu_ids, determine_root_gpu_device
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.base import EvalModelTemplate
|
||||
from torchtext.data import Batch, Dataset, Example, Field, LabelField
|
||||
|
||||
PRETEND_N_OF_GPUS = 16
|
||||
|
||||
|
@ -301,3 +302,32 @@ def test_single_gpu_batch_parse():
|
|||
|
||||
batch = trainer.transfer_batch_to_gpu(CustomBatchType())
|
||||
assert batch.a.type() == 'torch.cuda.FloatTensor'
|
||||
|
||||
# torchtext.data.Batch
|
||||
samples = [
|
||||
{'text': 'PyTorch Lightning is awesome!', 'label': 0},
|
||||
{'text': 'Please make it work with torchtext', 'label': 1}
|
||||
]
|
||||
|
||||
text_field = Field()
|
||||
label_field = LabelField()
|
||||
fields = {
|
||||
'text': ('text', text_field),
|
||||
'label': ('label', label_field)
|
||||
}
|
||||
|
||||
examples = [Example.fromdict(sample, fields) for sample in samples]
|
||||
dataset = Dataset(
|
||||
examples=examples,
|
||||
fields=fields.values()
|
||||
)
|
||||
|
||||
# Batch runs field.process() that numericalizes tokens, but it requires to build dictionary first
|
||||
text_field.build_vocab(dataset)
|
||||
label_field.build_vocab(dataset)
|
||||
|
||||
batch = Batch(data=examples, dataset=dataset)
|
||||
batch = trainer.transfer_batch_to_gpu(batch, 0)
|
||||
|
||||
assert batch.text.type() == 'torch.cuda.LongTensor'
|
||||
assert batch.label.type() == 'torch.cuda.LongTensor'
|
||||
|
|
Loading…
Reference in New Issue