Fix mypy typing for `utilities.apply_func` (#8781)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
Daniel Stancl 2021-08-26 18:36:22 +02:00 committed by GitHub
parent dfffb94b3c
commit 53885afc2e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 23 deletions

View File

@ -65,6 +65,7 @@ module = [
"pytorch_lightning.loops.closure", "pytorch_lightning.loops.closure",
"pytorch_lightning.trainer.evaluation_loop", "pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector", "pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.apply_func",
"pytorch_lightning.utilities.argparse", "pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli", "pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io", "pytorch_lightning.utilities.cloud_io",

View File

@ -18,7 +18,7 @@ from collections import OrderedDict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from copy import copy from copy import copy
from functools import partial from functools import partial
from typing import Any, Callable, Optional, Union from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -35,19 +35,23 @@ else:
Batch = type(None) Batch = type(None)
def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None): def to_dtype_tensor(
value: Union[int, float, List[Union[int, float]]],
dtype: Optional[torch.dtype] = None,
device: Union[str, torch.device] = None,
) -> torch.Tensor:
if device is None: if device is None:
raise MisconfigurationException("device (torch.device) should be provided.") raise MisconfigurationException("device (torch.device) should be provided.")
return torch.tensor(value, dtype=dtype, device=device) return torch.tensor(value, dtype=dtype, device=device)
def from_numpy(value, device: torch.device = None): def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
if device is None: if device is None:
raise MisconfigurationException("device (torch.device) should be provided.") raise MisconfigurationException("device (torch.device) should be provided.")
return torch.from_numpy(value).to(device) return torch.from_numpy(value).to(device)
CONVERSION_DTYPES = [ CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
# bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
(bool, partial(to_dtype_tensor, dtype=torch.uint8)), (bool, partial(to_dtype_tensor, dtype=torch.uint8)),
(int, partial(to_dtype_tensor, dtype=torch.int)), (int, partial(to_dtype_tensor, dtype=torch.int)),
@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
def _is_dataclass_instance(obj): def _is_dataclass_instance(obj: object) -> bool:
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
return dataclasses.is_dataclass(obj) and not isinstance(obj, type) return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
def apply_to_collection( def apply_to_collection(
data: Any, data: Any,
dtype: Union[type, tuple], dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable, function: Callable,
*args, *args: Any,
wrong_dtype: Optional[Union[type, tuple]] = None, wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
include_none: bool = True, include_none: bool = True,
**kwargs **kwargs: Any,
) -> Any: ) -> Any:
""" """
Recursively applies a function to all elements of a certain dtype. Recursively applies a function to all elements of a certain dtype.
@ -121,7 +125,7 @@ def apply_to_collection(
return elem_type(*out) if is_namedtuple else elem_type(out) return elem_type(*out) if is_namedtuple else elem_type(out)
if _is_dataclass_instance(data): if _is_dataclass_instance(data):
out = {} out_dict = {}
for field in data.__dataclass_fields__: for field in data.__dataclass_fields__:
v = apply_to_collection( v = apply_to_collection(
getattr(data, field), getattr(data, field),
@ -130,11 +134,11 @@ def apply_to_collection(
*args, *args,
wrong_dtype=wrong_dtype, wrong_dtype=wrong_dtype,
include_none=include_none, include_none=include_none,
**kwargs **kwargs,
) )
if include_none or v is not None: if include_none or v is not None:
out[field] = v out_dict[field] = v
return elem_type(**out) return elem_type(**out_dict)
# data is neither of dtype, nor a collection # data is neither of dtype, nor a collection
return data return data
@ -143,11 +147,11 @@ def apply_to_collection(
def apply_to_collections( def apply_to_collections(
data1: Optional[Any], data1: Optional[Any],
data2: Optional[Any], data2: Optional[Any],
dtype: Union[type, tuple], dtype: Union[type, Any, Tuple[Union[type, Any]]],
function: Callable, function: Callable,
*args, *args: Any,
wrong_dtype: Optional[Union[type, tuple]] = None, wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
**kwargs **kwargs: Any,
) -> Any: ) -> Any:
""" """
Zips two collections and applies a function to their items of a certain dtype. Zips two collections and applies a function to their items of a certain dtype.
@ -169,7 +173,9 @@ def apply_to_collections(
AssertionError: AssertionError:
If sequence collections have different data sizes. If sequence collections have different data sizes.
""" """
if data1 is None and data2 is not None: if data1 is None:
if data2 is None:
return
# in case they were passed reversed # in case they were passed reversed
data1, data2 = data2, None data1, data2 = data2, None
@ -220,14 +226,14 @@ class TransferableDataType(ABC):
""" """
@classmethod @classmethod
def __subclasshook__(cls, subclass): def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is TransferableDataType: if cls is TransferableDataType:
to = getattr(subclass, "to", None) to = getattr(subclass, "to", None)
return callable(to) return callable(to)
return NotImplemented return NotImplemented
def move_data_to_device(batch: Any, device: torch.device): def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
""" """
Transfers a collection of data to the given device. Any object that defines a method Transfers a collection of data to the given device. Any object that defines a method
``to(device)`` will be moved and all other objects in the collection will be left untouched. ``to(device)`` will be moved and all other objects in the collection will be left untouched.
@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
- :class:`torch.device` - :class:`torch.device`
""" """
def batch_to(data): def batch_to(data: Any) -> Any:
# try to move torchtext data first # try to move torchtext data first
if _TORCHTEXT_AVAILABLE and isinstance(data, Batch): if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
@ -269,14 +275,14 @@ def move_data_to_device(batch: Any, device: torch.device):
return apply_to_collection(batch, dtype=dtype, function=batch_to) return apply_to_collection(batch, dtype=dtype, function=batch_to)
def convert_to_tensors(data: Any, device: torch.device) -> Any: def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
if device is None: if device is None:
raise MisconfigurationException("`torch.device` should be provided.") raise MisconfigurationException("`torch.device` should be provided.")
for src_dtype, conversion_func in CONVERSION_DTYPES: for src_dtype, conversion_func in CONVERSION_DTYPES:
data = apply_to_collection(data, src_dtype, conversion_func, device=device) data = apply_to_collection(data, src_dtype, conversion_func, device=device)
def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor: def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
return t.to(device).contiguous() return t.to(device).contiguous()
data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device) data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device)