Fix mypy typing for `utilities.apply_func` (#8781)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
Daniel Stancl 2021-08-26 18:36:22 +02:00 committed by GitHub
parent dfffb94b3c
commit 53885afc2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 23 deletions

View File

@ -65,6 +65,7 @@ module = [
"pytorch_lightning.loops.closure",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.apply_func",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",

View File

@ -18,7 +18,7 @@ from collections import OrderedDict
from collections.abc import Mapping, Sequence
from copy import copy
from functools import partial
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@ -35,19 +35,23 @@ else:
Batch = type(None)
def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
def to_dtype_tensor(
value: Union[int, float, List[Union[int, float]]],
dtype: Optional[torch.dtype] = None,
device: Union[str, torch.device] = None,
) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.tensor(value, dtype=dtype, device=device)
def from_numpy(value, device: torch.device = None):
def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
if device is None:
raise MisconfigurationException("device (torch.device) should be provided.")
return torch.from_numpy(value).to(device)
CONVERSION_DTYPES = [
CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)),
@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
def _is_dataclass_instance(obj):
def _is_dataclass_instance(obj: object) -> bool:
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
def apply_to_collection(
data: Any,
dtype: Union[type, tuple],
dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
*args,
wrong_dtype: Optional[Union[type, tuple]] = None,
*args: Any,
wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
include_none: bool = True,
**kwargs
**kwargs: Any,
) -> Any:
"""
Recursively applies a function to all elements of a certain dtype.
@ -121,7 +125,7 @@ def apply_to_collection(
return elem_type(*out) if is_namedtuple else elem_type(out)
if _is_dataclass_instance(data):
out = {}
out_dict = {}
for field in data.__dataclass_fields__:
v = apply_to_collection(
getattr(data, field),
@ -130,11 +134,11 @@ def apply_to_collection(
*args,
wrong_dtype=wrong_dtype,
include_none=include_none,
**kwargs
**kwargs,
)
if include_none or v is not None:
out[field] = v
return elem_type(**out)
out_dict[field] = v
return elem_type(**out_dict)
# data is neither of dtype, nor a collection
return data
@ -143,11 +147,11 @@ def apply_to_collection(
def apply_to_collections(
data1: Optional[Any],
data2: Optional[Any],
dtype: Union[type, tuple],
dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable,
*args,
wrong_dtype: Optional[Union[type, tuple]] = None,
**kwargs
*args: Any,
wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
**kwargs: Any,
) -> Any:
"""
Zips two collections and applies a function to their items of a certain dtype.
@ -169,7 +173,9 @@ def apply_to_collections(
AssertionError:
If sequence collections have different data sizes.
"""
if data1 is None and data2 is not None:
if data1 is None:
if data2 is None:
return
# in case they were passed reversed
data1, data2 = data2, None
@ -220,14 +226,14 @@ class TransferableDataType(ABC):
"""
@classmethod
def __subclasshook__(cls, subclass):
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is TransferableDataType:
to = getattr(subclass, "to", None)
return callable(to)
return NotImplemented
def move_data_to_device(batch: Any, device: torch.device):
def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
"""
Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched.
@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
- :class:`torch.device`
"""
def batch_to(data):
def batch_to(data: Any) -> Any:
# try to move torchtext data first
if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
@ -269,14 +275,14 @@ def move_data_to_device(batch: Any, device: torch.device):
return apply_to_collection(batch, dtype=dtype, function=batch_to)
def convert_to_tensors(data: Any, device: torch.device) -> Any:
def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
if device is None:
raise MisconfigurationException("`torch.device` should be provided.")
for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, conversion_func, device=device)
def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor:
def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
return t.to(device).contiguous()
data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device)