Fix mypy typing for `utilities.apply_func` (#8781)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent dfffb94b3c
commit 53885afc2e
@@ -65,6 +65,7 @@ module = [
     "pytorch_lightning.loops.closure",
     "pytorch_lightning.trainer.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector",
+    "pytorch_lightning.utilities.apply_func",
     "pytorch_lightning.utilities.argparse",
     "pytorch_lightning.utilities.cli",
     "pytorch_lightning.utilities.cloud_io",
@@ -18,7 +18,7 @@ from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 from copy import copy
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -35,19 +35,23 @@ else:
     Batch = type(None)


-def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
+def to_dtype_tensor(
+    value: Union[int, float, List[Union[int, float]]],
+    dtype: Optional[torch.dtype] = None,
+    device: Union[str, torch.device] = None,
+) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.tensor(value, dtype=dtype, device=device)


-def from_numpy(value, device: torch.device = None):
+def from_numpy(value: np.ndarray, device: Union[str, torch.device] = None) -> torch.Tensor:
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
     return torch.from_numpy(value).to(device)


-CONVERSION_DTYPES = [
+CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any], torch.Tensor]]] = [
     # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
     (bool, partial(to_dtype_tensor, dtype=torch.uint8)),
     (int, partial(to_dtype_tensor, dtype=torch.int)),
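
For context, a minimal usage sketch of the two helpers with their new signatures; the toy values and the choice of "cpu" are illustrations, not part of the diff:

    import numpy as np
    import torch

    from pytorch_lightning.utilities.apply_func import from_numpy, to_dtype_tensor

    # the device is still mandatory at runtime; only the annotation widened to accept strings
    t = to_dtype_tensor([1, 2, 3], dtype=torch.int, device="cpu")
    n = from_numpy(np.ones(3), device="cpu")
    print(t.dtype, n.dtype)  # torch.int32 torch.float64
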
@@ -61,19 +65,19 @@ def _is_namedtuple(obj: object) -> bool:
     return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


-def _is_dataclass_instance(obj):
+def _is_dataclass_instance(obj: object) -> bool:
     # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
     return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


 def apply_to_collection(
     data: Any,
-    dtype: Union[type, tuple],
+    dtype: Union[type, Any, Tuple[Union[type, Any]]],
     function: Callable,
-    *args,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
     include_none: bool = True,
-    **kwargs
+    **kwargs: Any,
 ) -> Any:
     """
     Recursively applies a function to all elements of a certain dtype.
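
This hunk only annotates the signature of ``apply_to_collection``; as a reminder of what the function does, a minimal sketch with made-up data:

    import torch

    from pytorch_lightning.utilities.apply_func import apply_to_collection

    batch = {"x": torch.zeros(2), "y": [torch.ones(2), 3], "name": "sample"}
    # every torch.Tensor leaf is doubled; the int and the string are returned untouched
    doubled = apply_to_collection(batch, dtype=torch.Tensor, function=lambda t: t * 2)
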
@@ -121,7 +125,7 @@ def apply_to_collection(
         return elem_type(*out) if is_namedtuple else elem_type(out)

     if _is_dataclass_instance(data):
-        out = {}
+        out_dict = {}
         for field in data.__dataclass_fields__:
             v = apply_to_collection(
                 getattr(data, field),
@@ -130,11 +134,11 @@ def apply_to_collection(
                 *args,
                 wrong_dtype=wrong_dtype,
                 include_none=include_none,
-                **kwargs
+                **kwargs,
             )
             if include_none or v is not None:
-                out[field] = v
-        return elem_type(**out)
+                out_dict[field] = v
+        return elem_type(**out_dict)

     # data is neither of dtype, nor a collection
     return data
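
The rename from ``out`` to ``out_dict`` presumably avoids re-using the ``out`` name (a sequence in the branch above) as a dict, which mypy flags; the dataclass handling itself is unchanged. A small sketch of that code path, where the ``Sample`` dataclass is invented for illustration:

    from dataclasses import dataclass

    import torch

    from pytorch_lightning.utilities.apply_func import apply_to_collection

    @dataclass
    class Sample:
        image: torch.Tensor
        label: int

    s = Sample(image=torch.rand(3), label=7)
    # the dataclass is rebuilt field by field; only the tensor field matches dtype
    converted = apply_to_collection(s, dtype=torch.Tensor, function=lambda t: t.double())
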
@@ -143,11 +147,11 @@ def apply_to_collection(
 def apply_to_collections(
     data1: Optional[Any],
     data2: Optional[Any],
-    dtype: Union[type, tuple],
+    dtype: Union[type, Any, Tuple[Union[type, Any]]],
     function: Callable,
-    *args,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
-    **kwargs
+    *args: Any,
+    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
+    **kwargs: Any,
 ) -> Any:
     """
     Zips two collections and applies a function to their items of a certain dtype.
@@ -169,7 +173,9 @@ def apply_to_collections(
         AssertionError:
             If sequence collections have different data sizes.
     """
-    if data1 is None and data2 is not None:
+    if data1 is None:
+        if data2 is None:
+            return
         # in case they were passed reversed
         data1, data2 = data2, None

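The rewritten guard spells the cases out: if both inputs are ``None`` the function returns early, and if only ``data1`` is ``None`` the arguments are swapped so the single remaining collection is still traversed. A short sketch of normal use, with toy values:

    import torch

    from pytorch_lightning.utilities.apply_func import apply_to_collections

    a = {"loss": torch.tensor(1.0), "acc": torch.tensor(0.5)}
    b = {"loss": torch.tensor(3.0), "acc": torch.tensor(0.7)}
    # matching entries are zipped and the function is applied pairwise
    summed = apply_to_collections(a, b, dtype=torch.Tensor, function=torch.add)
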
@@ -220,14 +226,14 @@ class TransferableDataType(ABC):
     """

     @classmethod
-    def __subclasshook__(cls, subclass):
+    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
         if cls is TransferableDataType:
             to = getattr(subclass, "to", None)
             return callable(to)
         return NotImplemented


-def move_data_to_device(batch: Any, device: torch.device):
+def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
     """
     Transfers a collection of data to the given device. Any object that defines a method
     ``to(device)`` will be moved and all other objects in the collection will be left untouched.
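
``__subclasshook__`` performs a structural check: anything exposing a callable ``to`` attribute counts as a ``TransferableDataType``, which is why the annotation is the loose ``Union[bool, Any]`` (``NotImplemented`` on the non-matching branch). A minimal sketch; the ``WithTo`` class is made up:

    import torch

    from pytorch_lightning.utilities.apply_func import TransferableDataType

    class WithTo:
        def to(self, device):
            return self

    print(isinstance(WithTo(), TransferableDataType))       # True, via the duck-typed hook
    print(isinstance(torch.ones(1), TransferableDataType))  # True, tensors define .to()
    print(isinstance("plain string", TransferableDataType)) # False
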
@@ -245,7 +251,7 @@ def move_data_to_device(batch: Any, device: torch.device):
         - :class:`torch.device`
     """

-    def batch_to(data):
+    def batch_to(data: Any) -> Any:
         # try to move torchtext data first
         if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):
@@ -269,14 +275,14 @@ def move_data_to_device(batch: Any, device: torch.device):
     return apply_to_collection(batch, dtype=dtype, function=batch_to)


-def convert_to_tensors(data: Any, device: torch.device) -> Any:
+def convert_to_tensors(data: Any, device: Union[str, torch.device]) -> Any:
     if device is None:
         raise MisconfigurationException("`torch.device` should be provided.")

     for src_dtype, conversion_func in CONVERSION_DTYPES:
         data = apply_to_collection(data, src_dtype, conversion_func, device=device)

-    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device) -> torch.Tensor:
+    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: Union[str, torch.device]) -> torch.Tensor:
         return t.to(device).contiguous()

     data = apply_to_collection(data, torch.Tensor, _move_to_device_and_make_contiguous, device=device)
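
With the ``Union[str, torch.device]`` annotations, passing a plain device string now satisfies the type checker as well as the runtime. A small end-to-end sketch; the toy data and "cpu" device are chosen only so it runs anywhere:

    import numpy as np
    import torch

    from pytorch_lightning.utilities.apply_func import convert_to_tensors, move_data_to_device

    batch = {"x": torch.rand(2, 3), "y": [1, 2, 3]}
    batch = move_data_to_device(batch, "cpu")      # device given as a string

    data = {"flag": True, "values": np.arange(4)}
    data = convert_to_tensors(data, device="cpu")  # bools/ints/floats/ndarrays become tensors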