from collections.abc import Mapping, Sequence
from typing import Any, Callable, Union

import torch
def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, **kwargs) -> Any:
    """Recursively apply ``function`` to every element of ``data`` matching ``dtype``.

    Args:
        data: the collection (or single value) to traverse
        dtype: type, or tuple of types, whose instances will be transformed
        function: callable applied to each element that matches ``dtype``
        *args: extra positional arguments forwarded to every ``function`` call
        **kwargs: extra keyword arguments forwarded to every ``function`` call

    Returns:
        A collection of the same type with matching elements transformed, or
        ``function(data)`` when ``data`` itself matches ``dtype``, or ``data``
        unchanged when it is neither a match nor a supported collection.
    """
    container_type = type(data)

    # A matching element terminates the recursion.
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)

    def _recurse(item):
        # One step of the traversal, carrying the forwarded arguments along.
        return apply_to_collection(item, dtype, function, *args, **kwargs)

    # Rebuild supported containers with the same concrete type.
    if isinstance(data, Mapping):
        return container_type({key: _recurse(value) for key, value in data.items()})
    if isinstance(data, tuple) and hasattr(data, '_fields'):  # named tuple
        return container_type(*(_recurse(item) for item in data))
    if isinstance(data, Sequence) and not isinstance(data, str):
        return container_type([_recurse(item) for item in data])

    # Neither a dtype match nor a traversable collection: hand back untouched.
    return data
def move_data_to_device(batch: Any, device: torch.device):
    """Transfer a tensor, or a collection of tensors, to the given device.

    Args:
        batch: A tensor or collection of tensors. See :func:`apply_to_collection`
            for a list of supported collection types.
        device: The device to which tensors should be moved

    Return:
        the same collection but with all contained tensors residing on the new device.

    See Also:
        - :meth:`torch.Tensor.to`
        - :class:`torch.device`
    """
    # Only torch.Tensor leaves are touched; everything else passes through as-is.
    return apply_to_collection(
        batch,
        dtype=torch.Tensor,
        function=lambda tensor: tensor.to(device, non_blocking=True),
    )