Fix lost compatibility with custom datatypes implementing `.to` (#2335)
* generalize data transfer * added test * update docs * fix spelling error * changelog * update docs
This commit is contained in:
parent
598f5140c5
commit
aab9e77d2d
|
@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
|
||||
|
||||
- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335))
|
||||
|
||||
## [0.8.1] - 2020-06-19
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -204,7 +204,7 @@ class ModelHooks(Module):
|
|||
|
||||
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
|
||||
|
||||
- :class:`torch.Tensor`
|
||||
- :class:`torch.Tensor` or anything that implements `.to(...)`
|
||||
- :class:`list`
|
||||
- :class:`dict`
|
||||
- :class:`tuple`
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from abc import ABC
|
||||
from collections import Mapping, Sequence
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
|
@ -38,14 +39,43 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
|
|||
return data
|
||||
|
||||
|
||||
class TransferableDataType(ABC):
|
||||
"""
|
||||
A custom type for data that can be moved to a torch device via `.to(...)`.
|
||||
|
||||
Example:
|
||||
|
||||
>>> isinstance(dict, TransferableDataType)
|
||||
False
|
||||
>>> isinstance(torch.rand(2, 3), TransferableDataType)
|
||||
True
|
||||
>>> class CustomObject:
|
||||
... def __init__(self):
|
||||
... self.x = torch.rand(2, 2)
|
||||
... def to(self, device):
|
||||
... self.x = self.x.to(device)
|
||||
... return self
|
||||
>>> isinstance(CustomObject(), TransferableDataType)
|
||||
True
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, subclass):
|
||||
if cls is TransferableDataType:
|
||||
to = getattr(subclass, "to", None)
|
||||
return callable(to)
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def move_data_to_device(batch: Any, device: torch.device):
|
||||
"""
|
||||
Transfers a collection of tensors to the given device.
|
||||
Transfers a collection of data to the given device. Any object that defines a method
|
||||
``to(device)`` will be moved and all other objects in the collection will be left untouched.
|
||||
|
||||
Args:
|
||||
batch: A tensor or collection of tensors. See :func:`apply_to_collection`
|
||||
for a list of supported collection types.
|
||||
device: The device to which tensors should be moved
|
||||
batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
|
||||
See :func:`apply_to_collection` for a list of supported collection types.
|
||||
device: The device to which the data should be moved
|
||||
|
||||
Return:
|
||||
the same collection but with all contained tensors residing on the new device.
|
||||
|
@ -54,6 +84,6 @@ def move_data_to_device(batch: Any, device: torch.device):
|
|||
- :meth:`torch.Tensor.to`
|
||||
- :class:`torch.device`
|
||||
"""
|
||||
def to(tensor):
|
||||
return tensor.to(device, non_blocking=True)
|
||||
return apply_to_collection(batch, dtype=torch.Tensor, function=to)
|
||||
def to(data):
|
||||
return data.to(device, non_blocking=True)
|
||||
return apply_to_collection(batch, dtype=TransferableDataType, function=to)
|
||||
|
|
|
@ -286,3 +286,15 @@ def test_single_gpu_batch_parse():
|
|||
batch = trainer.transfer_batch_to_gpu(batch, 0)
|
||||
assert batch[0].a.device.index == 0
|
||||
assert batch[0].a.type() == 'torch.cuda.FloatTensor'
|
||||
|
||||
# non-Tensor that has `.to()` defined
|
||||
class CustomBatchType:
|
||||
def __init__(self):
|
||||
self.a = torch.rand(2, 2)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
self.a = self.a.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
batch = trainer.transfer_batch_to_gpu(CustomBatchType())
|
||||
assert batch.a.type() == 'torch.cuda.FloatTensor'
|
||||
|
|
Loading…
Reference in New Issue