Fix lost compatibility with custom datatypes implementing `.to` (#2335)

* generalize data transfer

* added test

* update docs

* fix spelling error

* changelog

* update docs
Adrian Wälchli 2020-06-24 05:41:02 +02:00, committed by GitHub
parent 598f5140c5
commit aab9e77d2d
4 changed files with 52 additions and 8 deletions


@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335))
## [0.8.1] - 2020-06-19
### Fixed


@@ -204,7 +204,7 @@ class ModelHooks(Module):
        The data types listed below (and any arbitrary nesting of them) are supported out of the box:

        - :class:`torch.Tensor`
        - :class:`torch.Tensor` or anything that implements `.to(...)`
        - :class:`list`
        - :class:`dict`
        - :class:`tuple`
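To make the supported nesting concrete, here is a minimal usage sketch; the import path for `move_data_to_device` is an assumption based on the utilities module edited below:

```python
import torch

# Assumed import path, matching the utilities module changed in this commit.
from pytorch_lightning.utilities.apply_func import move_data_to_device

# Arbitrary nesting of dict/list/tuple is traversed recursively; every leaf
# that implements `.to(...)` (tensors included) is moved to the target device.
batch = {
    "inputs": [torch.rand(2, 3), torch.rand(2, 3)],
    "extra": (torch.zeros(2), {"mask": torch.ones(2, dtype=torch.bool)}),
}
batch = move_data_to_device(batch, torch.device("cpu"))  # or "cuda:0" if available
```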


@@ -1,3 +1,4 @@
from abc import ABC
from collections import Mapping, Sequence
from typing import Any, Callable, Union
@@ -38,14 +39,43 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
    return data


class TransferableDataType(ABC):
    """
    A custom type for data that can be moved to a torch device via `.to(...)`.

    Example:

        >>> isinstance(dict, TransferableDataType)
        False
        >>> isinstance(torch.rand(2, 3), TransferableDataType)
        True
        >>> class CustomObject:
        ...     def __init__(self):
        ...         self.x = torch.rand(2, 2)
        ...     def to(self, device):
        ...         self.x = self.x.to(device)
        ...         return self
        >>> isinstance(CustomObject(), TransferableDataType)
        True
    """

    @classmethod
    def __subclasshook__(cls, subclass):
        if cls is TransferableDataType:
            to = getattr(subclass, "to", None)
            return callable(to)
        return NotImplemented


def move_data_to_device(batch: Any, device: torch.device):
    """
    Transfers a collection of tensors to the given device.
    Transfers a collection of data to the given device. Any object that defines a method
    ``to(device)`` will be moved and all other objects in the collection will be left untouched.

    Args:
        batch: A tensor or collection of tensors. See :func:`apply_to_collection`
            for a list of supported collection types.
        device: The device to which tensors should be moved
        batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
            See :func:`apply_to_collection` for a list of supported collection types.
        device: The device to which the data should be moved

    Return:
        the same collection but with all contained tensors residing on the new device.
@@ -54,6 +84,6 @@ def move_data_to_device(batch: Any, device: torch.device):
        - :meth:`torch.Tensor.to`
        - :class:`torch.device`
    """
    def to(tensor):
        return tensor.to(device, non_blocking=True)
    return apply_to_collection(batch, dtype=torch.Tensor, function=to)
    def to(data):
        return data.to(device, non_blocking=True)
    return apply_to_collection(batch, dtype=TransferableDataType, function=to)
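As a rough sketch of what the new duck typing buys (import path assumed): `isinstance` against `TransferableDataType` is a structural check for a callable `to` attribute, so any such object is picked up by `move_data_to_device`, while leaves without `.to` are left untouched:

```python
import torch

# Assumed import path for the helpers defined above.
from pytorch_lightning.utilities.apply_func import (
    TransferableDataType,
    move_data_to_device,
)


class WithTo:
    def __init__(self):
        self.x = torch.rand(2, 2)

    # move_data_to_device passes non_blocking=True, so a custom `to`
    # should accept extra arguments, e.g. via *args/**kwargs.
    def to(self, *args, **kwargs):
        self.x = self.x.to(*args, **kwargs)
        return self


assert isinstance(WithTo(), TransferableDataType)        # has a callable `.to`
assert not isinstance("a string", TransferableDataType)  # no `.to`, left as-is

# Tensors and the custom object are moved; the string stays untouched.
moved = move_data_to_device([torch.rand(2), WithTo(), "label"], torch.device("cpu"))
```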


@@ -286,3 +286,15 @@ def test_single_gpu_batch_parse():
    batch = trainer.transfer_batch_to_gpu(batch, 0)
    assert batch[0].a.device.index == 0
    assert batch[0].a.type() == 'torch.cuda.FloatTensor'

    # non-Tensor that has `.to()` defined
    class CustomBatchType:
        def __init__(self):
            self.a = torch.rand(2, 2)

        def to(self, *args, **kwargs):
            self.a = self.a.to(*args, **kwargs)
            return self

    batch = trainer.transfer_batch_to_gpu(CustomBatchType())
    assert batch.a.type() == 'torch.cuda.FloatTensor'
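The GPU test above only runs on CUDA machines; for what it's worth, the same duck-typed transfer can be checked on CPU by calling `move_data_to_device` directly (a sketch, import path assumed):

```python
import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device


class CustomBatchType:
    def __init__(self):
        self.a = torch.rand(2, 2)

    def to(self, *args, **kwargs):
        self.a = self.a.to(*args, **kwargs)
        return self


# The custom object is recognized via its `.to` method and moved in place.
batch = move_data_to_device(CustomBatchType(), torch.device("cpu"))
assert batch.a.device.type == "cpu"
```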