fix mypy typing errors in pytorch_lightning.utilities.data (#13901)

Co-authored-by: otaj <ota@lightning.ai>

parent 9b01a0fd32
commit 8e9780bd5b
pyproject.toml

@@ -53,8 +53,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.trainer.trainer",
-    "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.utilities.data",
-    "lightning_lite.utilities.data",
+    "pytorch_lightning.tuner.batch_size_scaling"
 ]
 ignore_errors = "True"
src/lightning_lite/utilities/data.py

@@ -21,7 +21,7 @@ from functools import partial
 from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union

 from lightning_utilities.core.inheritance import get_all_subclasses
-from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
+from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler

 from lightning_lite.utilities.enums import LightningEnum
 from lightning_lite.utilities.exceptions import MisconfigurationException
@@ -33,7 +33,8 @@ class _WrapAttrTag(LightningEnum):
     SET = "set"
     DEL = "del"

-    def __call__(self, *args):
+    def __call__(self, *args: Any) -> None:
+        fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
         if self == self.SET:
             fn = setattr
         else:
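Note on the `fn` annotation added above: without it, mypy infers the type of `fn` from the first assignment (`setattr`, a three-argument callable) and then rejects the two-argument `delattr` in the `else` branch. A minimal self-contained sketch of the same pattern (names here are illustrative, not from the patch):

```python
from typing import Any, Callable, Union

def apply_attr_op(op: str, obj: object, *args: Any) -> None:
    # declare the union of both callable shapes up front so mypy accepts
    # assigning callables of different arity in the two branches
    fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
    if op == "set":
        fn = setattr  # (object, name, value) -> None
    else:
        fn = delattr  # (object, name) -> None
    fn(obj, *args)
```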
@@ -45,12 +46,12 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool:
     return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)


-def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
+def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
     infinite dataloader."""
     try:
         # try getting the length
-        if len(dataloader) == 0:
+        if len(dataloader) == 0:  # type: ignore [arg-type]
             rank_zero_warn(
                 f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
             )
@@ -58,7 +59,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool:
     except (TypeError, NotImplementedError):
         has_len = False

-    if has_len and has_iterable_dataset(dataloader):
+    if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader):
         rank_zero_warn(
             "Your `IterableDataset` has `__len__` defined."
             " In combination with multi-process data loading (when num_workers > 1),"
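Note: `has_iterable_dataset` is still annotated to accept only a `DataLoader`, so once `has_len` widens its parameter to a `Union`, the call only type-checks behind an `isinstance` guard that narrows the union back down. A sketch of the narrowing, assuming `torch` is available:

```python
from typing import Iterable, Union
from torch.utils.data import DataLoader, Dataset, IterableDataset

def has_iterable_dataset(dataloader: DataLoader) -> bool:
    return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)

def warn_if_iterable_with_len(dl: Union[DataLoader, Iterable, Dataset]) -> bool:
    # inside the guard, mypy has narrowed `dl` from the Union to DataLoader,
    # so passing it to has_iterable_dataset type-checks
    return isinstance(dl, DataLoader) and has_iterable_dataset(dl)
```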
@@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]

 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
     if not isinstance(dataloader, DataLoader):
@@ -99,7 +100,7 @@ def _get_dataloader_init_args_and_kwargs(
         arg_names = ()

     # get the dataloader instance `__init__` parameters
-    params = dict(inspect.signature(dataloader.__init__).parameters)
+    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
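Note: the `# type: ignore[misc]` here silences mypy's "accessing `__init__` on an instance is unsound" error. That access is deliberate: the point is to recover the signature of the user's concrete `DataLoader` subclass at runtime. A sketch with a hypothetical subclass:

```python
import inspect
from torch.utils.data import DataLoader

class MyLoader(DataLoader):  # hypothetical user subclass
    def __init__(self, dataset, my_flag: bool = False, **kwargs):
        self.my_flag = my_flag
        super().__init__(dataset, **kwargs)

loader = MyLoader([1, 2, 3])
# the bound-method signature excludes `self`; mypy needs the ignore, runtime is fine
params = dict(inspect.signature(loader.__init__).parameters)  # type: ignore[misc]
print(list(params))  # ['dataset', 'my_flag', 'kwargs']
```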
@@ -141,14 +142,14 @@ def _get_dataloader_init_args_and_kwargs(
     }
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
-        required_args = sorted(required_args)
+        sorted_required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
-        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
         raise MisconfigurationException(
             f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
-            "`*_dataloader` hook of your module, we will do this for you."
+            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
+            "inside a `*_dataloader` hook of your module, we will do this for you."
             f" Otherwise, define {missing_args_message} inside your `__init__`."
         )
@@ -156,13 +157,13 @@ def _get_dataloader_init_args_and_kwargs(
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
-            missing_kwargs = sorted(missing_kwargs)
+            sorted_missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise TypeError(
                 f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
-                "add the `__init__` arguments or allow passing `**kwargs`"
+                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
+                "class, add the `__init__` arguments or allow passing `**kwargs`"
             )

     return dl_args, dl_kwargs
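Note on the `sorted_*` renames in the two hunks above: they are not cosmetic. `required_args` and `missing_kwargs` start out as sets, and `sorted()` returns a list, so re-binding the same name to the sorted result trips mypy's assignment check. Roughly:

```python
from typing import Set

required_args: Set[str] = {"b", "a"}

# re-binding changes the type and mypy rejects it:
#   error: Incompatible types in assignment
#   (expression has type "List[str]", variable has type "Set[str]")  [assignment]
# required_args = sorted(required_args)

# binding the list to a fresh name keeps both types consistent
sorted_required_args = sorted(required_args)
```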
@@ -170,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs(

 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]:
     """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its
@@ -334,7 +335,7 @@ def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
     :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

     @functools.wraps(method)
-    def wrapper(obj: Any, *args: Any):
+    def wrapper(obj: Any, *args: Any) -> None:
         # First, let's find out if we're the first in inheritance chain calling the patched method.
         name, *_ = args
         prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
src/pytorch_lightning/strategies/ipu.py

@@ -245,7 +245,7 @@ class IPUStrategy(ParallelStrategy):
             return dataloader

         dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(
-            dataloader, sampler, mode, self.replication_factor > 1  # type: ignore[arg-type]
+            dataloader, sampler, mode, self.replication_factor > 1
         )
         opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
         dataloader = _reinstantiate_wrapped_cls(
src/pytorch_lightning/utilities/auto_restart.py

@@ -62,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """

-    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Union[Sampler, Iterable], attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
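Note: the old `Iterator` bound was arguably wrong as well as too narrow; a `Sampler` implements `__iter__` but not `__next__`, i.e. it is an iterable rather than an iterator, which the widened `Union[Sampler, Iterable]` reflects. A quick runtime check:

```python
from typing import Iterator
from torch.utils.data import Sampler, SequentialSampler

sampler = SequentialSampler(range(10))
assert isinstance(sampler, Sampler)         # accepted by the new annotation
assert not isinstance(sampler, Iterator)    # defines __iter__ only, no __next__
assert isinstance(iter(sampler), Iterator)  # iterating it yields an iterator
```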
src/pytorch_lightning/utilities/data.py

@@ -30,7 +30,6 @@ from torch.utils.data import (
 )

 import pytorch_lightning as pl
-from lightning_lite.utilities import LightningEnum
 from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args
 from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset
 from lightning_lite.utilities.data import has_len as new_has_len
@@ -41,24 +40,13 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

-BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]
+# might be supported in later releases, see https://github.com/python/mypy/pull/13297
+BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]  # type: ignore[misc]

 warning_cache = WarningCache()


-class _WrapAttrTag(LightningEnum):
-    SET = "set"
-    DEL = "del"
-
-    def __call__(self, *args):
-        if self == self.SET:
-            fn = setattr
-        else:
-            fn = delattr
-        return fn(*args)
-
-
-def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
+def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]:
     if isinstance(batch, Tensor):
         if batch.ndim == 0:
             yield 1
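Note: the `BType` ignore is needed because mypy could not resolve recursive type aliases at the time of this change (python/mypy#13297, linked in the new comment, later added opt-in support). The duplicated `_WrapAttrTag` is dropped in favor of the typed copy in lightning_lite, and the widened `Generator[Optional[int], None, None]` return reflects that `_extract_batch_size` can yield `None` for batch elements whose size cannot be determined. A minimal, torch-free reproduction of the recursive-alias problem:

```python
from typing import Iterable, Mapping, Union

# mypy (without recursive-alias support) reports:
#   error: Cannot resolve name "JSON" (possible cyclic definition)  [misc]
JSON = Union[str, int, Mapping[str, "JSON"], Iterable["JSON"]]  # type: ignore[misc]

payload: JSON = {"a": [1, "two", {"b": 3}]}  # nests to arbitrary depth
```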
@@ -109,7 +97,7 @@ def extract_batch_size(batch: BType) -> int:

 def has_len_all_ranks(
     dataloader: DataLoader,
-    strategy: "pl.Strategy",
+    strategy: "pl.strategies.Strategy",
     model: Union["pl.LightningModule", "pl.LightningDataModule"],
 ) -> bool:
     """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or
@@ -151,14 +139,14 @@ def has_len_all_ranks(
     return has_len


-def get_len(dataloader: DataLoader) -> Union[int, float]:
+def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]:
     """Return the length of the given DataLoader.

     If ``__len__`` method is not implemented, return float('inf').
     """

     if new_has_len(dataloader):
-        return len(dataloader)
+        return len(dataloader)  # type: ignore [arg-type]

     return float("inf")

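Note: `new_has_len` only proves at runtime that `len()` will succeed; mypy cannot connect that guard to the `len(dataloader)` call, hence the local ignore alongside the widened signature. A usage sketch (the `Stream` class is illustrative):

```python
from torch.utils.data import DataLoader, IterableDataset
from pytorch_lightning.utilities.data import get_len

class Stream(IterableDataset):
    def __iter__(self):
        while True:
            yield 0

print(get_len(DataLoader([1, 2, 3])))  # 3   -- the list dataset defines __len__
print(get_len(DataLoader(Stream())))   # inf -- len() raises TypeError, has_len is False
```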
@@ -173,7 +161,7 @@ def _update_dataloader(

 def _get_dataloader_init_args_and_kwargs(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Tuple[Tuple[Any], Dict[str, Any]]:
@@ -197,7 +185,7 @@ def _get_dataloader_init_args_and_kwargs(
         arg_names = ()

     # get the dataloader instance `__init__` parameters
-    params = dict(inspect.signature(dataloader.__init__).parameters)
+    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
     has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
     if has_variadic_kwargs:
         # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`
@@ -239,14 +227,14 @@ def _get_dataloader_init_args_and_kwargs(
     }
     # the dataloader has required args which we could not extract from the existing attributes
     if required_args:
-        required_args = sorted(required_args)
+        sorted_required_args = sorted(required_args)
         dataloader_cls_name = dataloader.__class__.__name__
-        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args)
+        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
         raise MisconfigurationException(
             f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a "
-            "`*_dataloader` hook of your module, we will do this for you."
+            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
+            "inside a `*_dataloader` hook of your module, we will do this for you."
             f" Otherwise, define {missing_args_message} inside your `__init__`."
         )
@@ -254,13 +242,13 @@ def _get_dataloader_init_args_and_kwargs(
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
         if missing_kwargs:
-            missing_kwargs = sorted(missing_kwargs)
+            sorted_missing_kwargs = sorted(missing_kwargs)
             dataloader_cls_name = dataloader.__class__.__name__
             raise MisconfigurationException(
                 f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                 "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, "
-                "add the `__init__` arguments or allow passing `**kwargs`"
+                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
+                "class, add the `__init__` arguments or allow passing `**kwargs`"
             )

     if _FaultTolerantMode.detect_current_mode().is_automatic:
@@ -273,7 +261,7 @@ def _get_dataloader_init_args_and_kwargs(

 def _dataloader_init_kwargs_resolve_sampler(
     dataloader: DataLoader,
-    sampler: Optional[Sampler],
+    sampler: Union[Sampler, Iterable],
     mode: Optional[RunningStage] = None,
     disallow_batch_sampler: bool = False,
 ) -> Dict[str, Any]: