Fix mypy typing errors in pytorch_lightning/utilities/data.py (#13901)

Co-authored-by: otaj <ota@lightning.ai>
Ritik Nandwal 2022-09-14 16:51:57 +05:30 committed by GitHub
parent 9b01a0fd32
commit 8e9780bd5b
5 changed files with 36 additions and 49 deletions

pyproject.toml

@@ -53,8 +53,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.trainer.trainer",
-    "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.utilities.data",
-    "lightning_lite.utilities.data",
+    "pytorch_lightning.tuner.batch_size_scaling"
 ]
 ignore_errors = "True"

src/lightning_lite/utilities/data.py

@@ -21,7 +21,7 @@ from functools import partial
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union
 
 from lightning_utilities.core.inheritance import get_all_subclasses
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler
 
 from lightning_lite.utilities.enums import LightningEnum
 from lightning_lite.utilities.exceptions import MisconfigurationException
@@ -33,7 +33,8 @@ class _WrapAttrTag(LightningEnum):
     SET = "set"
     DEL = "del"
 
-    def __call__(self, *args):
+    def __call__(self, *args: Any) -> None:
+        fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
         if self == self.SET:
             fn = setattr
         else:
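The `fn` annotation added above works around mypy's inference: without it, `fn` gets its type from the first assignment (`setattr`), and the `delattr` branch then fails with an incompatible-assignment error. A minimal sketch of the declare-before-branching pattern (`dispatch` is a hypothetical stand-in, not part of this commit):

from typing import Any, Callable, Union

def dispatch(set_mode: bool, *args: Any) -> None:
    # Declare the union of both callables up front; otherwise mypy infers
    # the type of `fn` from the first branch and rejects `fn = delattr`.
    fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
    if set_mode:
        fn = setattr
    else:
        fn = delattr
    fn(*args)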
@@ -45,12 +46,12 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
     return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
 
 
-def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
+def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
     infinite dataloader."""
     try:
         # try getting the length
-        if len(dataloader) == 0:
+        if len(dataloader) == 0:  # type: ignore [arg-type]
             rank_zero_warn(
                 f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
             )
@@ -58,7 +59,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
     except (TypeError, NotImplementedError):
         has_len = False
 
-    if has_len and has_iterable_dataset(dataloader):
+    if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]
 
 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
@@ -99,7 +100,7 @@ def _get_dataloader_init_args_and_kwargs(
         arg_names = ()
 
     # get the dataloader instance `__init__` parameters
-    params = dict(inspect.signature(dataloader.__init__).parameters)
+    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
@@ -141,14 +142,14 @@ def _get_dataloader_init_args_and_kwargs(
     }
 
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
-        required_args = sorted(required_args)
+        sorted_required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
-        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
         raise MisconfigurationException(
             f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
-            "`*_dataloader` hook of your module, we will do this for you."
+            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
+            "inside a `*_dataloader` hook of your module, we will do this for you."
             f" Otherwise, define {missing_args_message} inside your `__init__`."
         )
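Renaming `required_args` to `sorted_required_args` (and `missing_kwargs` to `sorted_missing_kwargs` below) avoids rebinding one name to two types: `sorted()` turns the original set into a list, and mypy flags such re-assignments as incompatible. A short sketch of the same fix on a hypothetical `report` helper:

from typing import List, Set

def report(missing: Set[str]) -> str:
    # `missing = sorted(missing)` would rebind a Set[str] name to a List[str]
    # value, which mypy rejects; a fresh name keeps both types stable.
    sorted_missing: List[str] = sorted(missing)
    return ", ".join(sorted_missing)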
@@ -156,13 +157,13 @@ def _get_dataloader_init_args_and_kwargs(
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
-            missing_kwargs = sorted(missing_kwargs)
+            sorted_missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise TypeError(
                 f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
-                "add the `__init__` arguments or allow passing `**kwargs`"
+                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
+                "class, add the `__init__` arguments or allow passing `**kwargs`"
             )
 
     return dl_args, dl_kwargs
@@ -170,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs(
 
 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
@@ -334,7 +335,7 @@ def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
     :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
 
     @functools.wraps(method)
-    def wrapper(obj: Any, *args: Any):
+    def wrapper(obj: Any, *args: Any) -> None:
         # First, let's find out if we're the first in inheritance chain calling the patched method.
         name, *_ = args
         prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
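The `-> None` added to `wrapper` matters beyond style: unannotated inner functions are skipped by mypy under its default settings (they are only checked with `--check-untyped-defs`), so the annotation is what brings the wrapper body under type checking at all. A minimal sketch of an annotated `functools.wraps` wrapper (hypothetical `log_calls` decorator):

import functools
from typing import Any, Callable

def log_calls(method: Callable) -> Callable:
    # Fully annotating the inner wrapper lets mypy check its body; an
    # unannotated `def wrapper(obj, *args):` would be skipped entirely.
    @functools.wraps(method)
    def wrapper(obj: Any, *args: Any) -> None:
        print(f"calling {method.__name__}")
        method(obj, *args)
    return wrapper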

src/pytorch_lightning/strategies/ipu.py

@@ -245,7 +245,7 @@ class IPUStrategy(ParallelStrategy):
             return dataloader
 
         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1  # type: ignore[arg-type]
+            dataloader, sampler, mode, self.replication_factor > 1
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = _reinstantiate_wrapped_cls(

src/pytorch_lightning/utilities/auto_restart.py

@@ -62,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """
 
-    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Union[Sampler, Iterable], attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
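The old `sampler: Iterator` annotation was inaccurate: a `Sampler` implements `__iter__` but not `__next__`, so it is an `Iterable`, not an `Iterator`, and the widened `Union[Sampler, Iterable]` reflects what is actually passed in. A sketch of the distinction, assuming torch is installed (`consume` is a hypothetical helper):

from typing import Iterable
from torch.utils.data import SequentialSampler

def consume(source: Iterable) -> list:
    # A sampler has __iter__ but no __next__: it satisfies Iterable, not
    # Iterator, so an Iterator annotation here would be rejected by mypy.
    return list(source)

print(consume(SequentialSampler(range(3))))  # [0, 1, 2]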

src/pytorch_lightning/utilities/data.py

@@ -30,7 +30,6 @@ from torch.utils.data import (
 )
 
 import pytorch_lightning as pl
-from lightning_lite.utilities import LightningEnum
 from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
 from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset
 from lightning_lite.utilities.data import has_len as new_has_len
@@ -41,24 +40,13 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
 
-BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
+# might be supported in later releases, see https://github.com/python/mypy/pull/13297
+BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]
 
 warning_cache = WarningCache()
 
 
-class _WrapAttrTag(LightningEnum):
-    SET = "set"
-    DEL = "del"
-
-    def __call__(self, *args):
-        if self == self.SET:
-            fn = setattr
-        else:
-            fn = delattr
-        return fn(*args)
-
-
-def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
+def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]:
     if isinstance(batch, Tensor):
         if batch.ndim == 0:
             yield 1
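The `# type: ignore[misc]` on `BType` is needed because the alias refers to itself; when this commit landed, mypy rejected recursive type aliases (the in-diff comment links python/mypy#13297, which added experimental support later). The same pattern applies to any self-referencing alias, e.g. a JSON-like type:

from typing import Dict, List, Union

# Self-referencing aliases were a mypy error at the time of this commit,
# hence the ignore; newer mypy versions accept them.
Json = Union[str, int, float, bool, None, Dict[str, "Json"], List["Json"]]  # type: ignore[misc]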
@@ -109,7 +97,7 @@ def extract_batch_size(batch: BType) -> int:
 
 def has_len_all_ranks(
     dataloader: DataLoader,
-    strategy: "pl.Strategy",
+    strategy: "pl.strategies.Strategy",
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
@@ -151,14 +139,14 @@ def has_len_all_ranks(
     return has_len
 
 
-def get_len(dataloader: DataLoader) -> Union[int, float]:
+def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]:
     """Return the length of the given DataLoader.
 
     If ``__len__`` method is not implemented, return float('inf').
     """
 
     if new_has_len(dataloader):
-        return len(dataloader)
+        return len(dataloader)  # type: ignore [arg-type]
 
     return float("inf")
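The ignore in `get_len` covers a runtime guarantee the checker cannot see: `new_has_len` only returns True when `__len__` exists, but mypy still types `dataloader` as the full union, where a plain `Dataset` need not be sized. A minimal sketch of the same situation (hypothetical `safe_len`; mypy, at least at the time, did not narrow on such runtime checks):

from typing import Iterable, Union

def safe_len(obj: Union[Iterable, object]) -> Union[int, float]:
    # hasattr() does not narrow the declared type, so mypy still rejects
    # len(obj) with [arg-type]; the runtime check makes the ignore safe.
    if hasattr(obj, "__len__"):
        return len(obj)  # type: ignore[arg-type]
    return float("inf")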
@@ -173,7 +161,7 @@ def _update_dataloader(
 
 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
@@ -197,7 +185,7 @@ def _get_dataloader_init_args_and_kwargs(
         arg_names = ()
 
     # get the dataloader instance `__init__` parameters
-    params = dict(inspect.signature(dataloader.__init__).parameters)
+    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
@@ -239,14 +227,14 @@ def _get_dataloader_init_args_and_kwargs(
     }
 
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
-        required_args = sorted(required_args)
+        sorted_required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
-        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
         raise MisconfigurationException(
             f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
-            "`*_dataloader` hook of your module, we will do this for you."
+            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
+            "inside a `*_dataloader` hook of your module, we will do this for you."
             f" Otherwise, define {missing_args_message} inside your `__init__`."
         )
@@ -254,13 +242,13 @@ def _get_dataloader_init_args_and_kwargs(
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
-            missing_kwargs = sorted(missing_kwargs)
+            sorted_missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise MisconfigurationException(
                 f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
-                "add the `__init__` arguments or allow passing `**kwargs`"
+                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
+                "class, add the `__init__` arguments or allow passing `**kwargs`"
             )
 
     if _FaultTolerantMode.detect_current_mode().is_automatic:
@@ -273,7 +261,7 @@ def _get_dataloader_init_args_and_kwargs(
 
 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]: