diff --git a/pyproject.toml b/pyproject.toml
index 206a4717a1..e848464c9d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.loops.closure",
     "pytorch_lightning.trainer.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector",
+    "pytorch_lightning.utilities.apply_func",
     "pytorch_lightning.utilities.argparse",
     "pytorch_lightning.utilities.cli",
     "pytorch_lightning.utilities.cloud_io",
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index b96a0110e5..d7d09251f6 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -18,7 +18,7 @@ from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -35,19 +35,23 @@ else:
     Batch = type(None)
 
 
-def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
+def to_dtype_tensor(
+    value: Union[int, float, List[Union[int, float]]],
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[Union[str, torch.device]] = None,  # None is rejected at runtime below
+) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.tensor(value, dtype=dtype, device=device)
 
 
-def from_numpy(value, device: torch.device = None):
+def from_numpy(value: np.ndarray, device: Optional[Union[str, torch.device]] = None) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.from_numpy(value).to(device)
 
 
-CONVERSION_DTYPES = [
+CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
     # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
     (bool, partial(to_dtype_tensor, dtype=torch.uint8)),
     (int, partial(to_dtype_tensor, dtype=torch.int)),
@@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
     return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
 
 
-def _is_dataclass_instance(obj):
+def _is_dataclass_instance(obj: object) -> bool:
     # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
     return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
 
 
 def apply_to_collection(
     data: Any,
-    dtype: Union[type, tuple],
+    dtype: Union[type, Any, Tuple[Union[type, Any], ...]],
     function: Callable,
-    *args,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
     include_none: bool = True,
-    **kwargs
+    **kwargs: Any,
 ) -> Any:
     """
     Recursively applies a function to all elements of a certain dtype.
@@ -121,7 +125,7 @@ def apply_to_collection(
         return elem_type(*out) if is_namedtuple else elem_type(out)
 
     if _is_dataclass_instance(data):
-        out = {}
+        out_dict = {}
         for field in data.__dataclass_fields__:
             v = apply_to_collection(
                 getattr(data, field),
@@ -130,11 +134,11 @@ def apply_to_collection(
                 *args,
                 wrong_dtype=wrong_dtype,
                 include_none=include_none,
-                **kwargs
+                **kwargs,
             )
             if include_none or v is not None:
-                out[field] = v
-        return elem_type(**out)
+                out_dict[field] = v
+        return elem_type(**out_dict)
 
     # data is neither of dtype, nor a collection
     return data
@@ -143,11 +147,11 @@ def apply_to_collection(
 def apply_to_collections(
     data1: Optional[Any],
     data2: Optional[Any],
-    dtype: Union[type, tuple],
+    dtype: Union[type, Any, Tuple[Union[type, Any], ...]],
     function: Callable,
-    *args,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
-    **kwargs
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
+    **kwargs: Any,
 ) -> Any:
     """
     Zips two collections and applies a function to their items of a certain dtype.
@@ -169,7 +173,9 @@ def apply_to_collections(
         AssertionError:
             If sequence collections have different data sizes.
     """
-    if data1 is None and data2 is not None:
+    if data1 is None:
+        if data2 is None:
+            return None  # both collections missing: nothing to zip
         # in case they were passed reversed
         data1, data2 = data2, None
 
@@ -220,14 +226,14 @@ class TransferableDataType(ABC):
     """
 
     @classmethod
-    def __subclasshook__(cls, subclass):
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
         if cls is TransferableDataType:
             to = getattr(subclass, "to", None)
             return callable(to)
         return NotImplemented
 
 
-def move_data_to_device(batch: Any, device: torch.device):
+def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
     """
     Transfers a collection of data to the given device. Any object that defines a method
     ``to(device)`` will be moved and all other objects in the collection will be left untouched.
@@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
         - :class:`torch.device`
     """
 
-    def batch_to(data):
+    def batch_to(data: Any) -> Any:
         # try to move torchtext data first
         if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
 
@@ -269,14 +275,14 @@ def move_data_to_device(batch: Any, device: torch.device):
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
 
 
-def convert_to_tensors(data: Any, device: torch.device) -> Any:
+def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
     if device is None:
         raise MisconfigurationException("`torch.device` should be provided.")
 
     for src_dtype, conversion_func in CONVERSION_DTYPES:
         data = apply_to_collection(data, src_dtype, conversion_func, device=device)
 
-    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor:
+    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
         return t.to(device).contiguous()
 
     data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device)