diff --git a/CHANGELOG.md b/CHANGELOG.md index 26dcb32ad8..6871b7d7ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319)) +- Fixed lost compatibility with custom datatypes implementing `.to` ([#2335](https://github.com/PyTorchLightning/pytorch-lightning/pull/2335)) + ## [0.8.1] - 2020-06-19 ### Fixed diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 498a3ef038..46a92433c5 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -204,7 +204,7 @@ class ModelHooks(Module): The data types listed below (and any arbitrary nesting of them) are supported out of the box: - - :class:`torch.Tensor` + - :class:`torch.Tensor` or anything that implements `.to(...)` - :class:`list` - :class:`dict` - :class:`tuple` diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index bb32f79df9..6e0b530e8f 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -1,3 +1,4 @@ +from abc import ABC from collections import Mapping, Sequence from typing import Any, Callable, Union @@ -38,14 +39,43 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable return data +class TransferableDataType(ABC): + """ + A custom type for data that can be moved to a torch device via `.to(...)`. + + Example: + + >>> isinstance(dict, TransferableDataType) + False + >>> isinstance(torch.rand(2, 3), TransferableDataType) + True + >>> class CustomObject: + ... def __init__(self): + ... self.x = torch.rand(2, 2) + ... def to(self, device): + ... self.x = self.x.to(device) + ... return self + >>> isinstance(CustomObject(), TransferableDataType) + True + """ + + @classmethod + def __subclasshook__(cls, subclass): + if cls is TransferableDataType: + to = getattr(subclass, "to", None) + return callable(to) + return NotImplemented + + def move_data_to_device(batch: Any, device: torch.device): """ - Transfers a collection of tensors to the given device. + Transfers a collection of data to the given device. Any object that defines a method + ``to(device)`` will be moved and all other objects in the collection will be left untouched. Args: - batch: A tensor or collection of tensors. See :func:`apply_to_collection` - for a list of supported collection types. - device: The device to which tensors should be moved + batch: A tensor or collection of tensors or anything that has a method `.to(...)`. + See :func:`apply_to_collection` for a list of supported collection types. + device: The device to which the data should be moved Return: the same collection but with all contained tensors residing on the new device. @@ -54,6 +84,6 @@ def move_data_to_device(batch: Any, device: torch.device): - :meth:`torch.Tensor.to` - :class:`torch.device` """ - def to(tensor): - return tensor.to(device, non_blocking=True) - return apply_to_collection(batch, dtype=torch.Tensor, function=to) + def to(data): + return data.to(device, non_blocking=True) + return apply_to_collection(batch, dtype=TransferableDataType, function=to) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 65c44c0e2d..3db39e0b03 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -286,3 +286,15 @@ def test_single_gpu_batch_parse(): batch = trainer.transfer_batch_to_gpu(batch, 0) assert batch[0].a.device.index == 0 assert batch[0].a.type() == 'torch.cuda.FloatTensor' + + # non-Tensor that has `.to()` defined + class CustomBatchType: + def __init__(self): + self.a = torch.rand(2, 2) + + def to(self, *args, **kwargs): + self.a = self.a.to(*args, **kwargs) + return self + + batch = trainer.transfer_batch_to_gpu(CustomBatchType()) + assert batch.a.type() == 'torch.cuda.FloatTensor'