From 8e9780bd5bd1275cc417dafe3b093f4968b2aaed Mon Sep 17 00:00:00 2001 From: Ritik Nandwal <48522685+nandwalritik@users.noreply.github.com> Date: Wed, 14 Sep 2022 16:51:57 +0530 Subject: [PATCH] fix mypy typing errors in pytorch_lightning.utilities.data.py (#13901) Co-authored-by: otaj --- pyproject.toml | 4 +- src/lightning_lite/utilities/data.py | 33 +++++++------- src/pytorch_lightning/strategies/ipu.py | 2 +- .../utilities/auto_restart.py | 2 +- src/pytorch_lightning/utilities/data.py | 44 +++++++------------ 5 files changed, 36 insertions(+), 49 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 777f86841a..dbf58177e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,8 +53,6 @@ warn_no_return = "False" module = [ "pytorch_lightning.callbacks.progress.rich_progress", "pytorch_lightning.trainer.trainer", - "pytorch_lightning.tuner.batch_size_scaling", - "pytorch_lightning.utilities.data", - "lightning_lite.utilities.data", + "pytorch_lightning.tuner.batch_size_scaling" ] ignore_errors = "True" diff --git a/src/lightning_lite/utilities/data.py b/src/lightning_lite/utilities/data.py index cdaf806a0c..ca50344567 100644 --- a/src/lightning_lite/utilities/data.py +++ b/src/lightning_lite/utilities/data.py @@ -21,7 +21,7 @@ from functools import partial from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union from lightning_utilities.core.inheritance import get_all_subclasses -from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler +from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler from lightning_lite.utilities.enums import LightningEnum from lightning_lite.utilities.exceptions import MisconfigurationException @@ -33,7 +33,8 @@ class _WrapAttrTag(LightningEnum): SET = "set" DEL = "del" - def __call__(self, *args): + def __call__(self, *args: Any) -> None: + fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]] if self == self.SET: fn = setattr else: @@ -45,12 +46,12 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool: return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset) -def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: +def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or infinite dataloader.""" try: # try getting the length - if len(dataloader) == 0: + if len(dataloader) == 0: # type: ignore [arg-type] rank_zero_warn( f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention." ) @@ -58,7 +59,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: except (TypeError, NotImplementedError): has_len = False - if has_len and has_iterable_dataset(dataloader): + if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader): rank_zero_warn( "Your `IterableDataset` has `__len__` defined." " In combination with multi-process data loading (when num_workers > 1)," @@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable] def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], disallow_batch_sampler: bool = False, ) -> Tuple[Tuple[Any], Dict[str, Any]]: if not isinstance(dataloader, DataLoader): @@ -99,7 +100,7 @@ def _get_dataloader_init_args_and_kwargs( arg_names = () # get the dataloader instance `__init__` parameters - params = dict(inspect.signature(dataloader.__init__).parameters) + params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore[misc] has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) if has_variadic_kwargs: # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)` @@ -141,14 +142,14 @@ def _get_dataloader_init_args_and_kwargs( } # the dataloader has required args which we could not extract from the existing attributes if required_args: - required_args = sorted(required_args) + sorted_required_args = sorted(required_args) dataloader_cls_name = dataloader.__class__.__name__ - missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args) + missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args) raise MisconfigurationException( f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. " "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a " - "`*_dataloader` hook of your module, we will do this for you." + f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` " + "inside a `*_dataloader` hook of your module, we will do this for you." f" Otherwise, define {missing_args_message} inside your `__init__`." ) @@ -156,13 +157,13 @@ def _get_dataloader_init_args_and_kwargs( # the dataloader signature does not allow keyword arguments that need to be passed missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys() if missing_kwargs: - missing_kwargs = sorted(missing_kwargs) + sorted_missing_kwargs = sorted(missing_kwargs) dataloader_cls_name = dataloader.__class__.__name__ raise TypeError( f"Trying to inject parameters into the `{dataloader_cls_name}` instance. " "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, " - "add the `__init__` arguments or allow passing `**kwargs`" + f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` " + "class, add the `__init__` arguments or allow passing `**kwargs`" ) return dl_args, dl_kwargs @@ -170,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], disallow_batch_sampler: bool = False, ) -> Dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its @@ -334,7 +335,7 @@ def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @functools.wraps(method) - def wrapper(obj: Any, *args: Any): + def wrapper(obj: Any, *args: Any) -> None: # First, let's find out if we're the first in inheritance chain calling the patched method. name, *_ = args prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method")) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 64898e6c76..966789a07f 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -245,7 +245,7 @@ class IPUStrategy(ParallelStrategy): return dataloader dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs( - dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type] + dataloader, sampler, mode, self.replication_factor > 1 ) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts dataloader = _reinstantiate_wrapped_cls( diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index d9d8c5da38..34033b898f 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -62,7 +62,7 @@ class FastForwardSampler(Sampler): samples seen in the last iterations (for the current worker). """ - def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None: + def __init__(self, sampler: Union[Sampler, Iterable], attr_name: Optional[str] = None) -> None: super().__init__(data_source=None) self._sampler = sampler self.restarting: bool = False diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index cf07949461..17f8b9f101 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -30,7 +30,6 @@ from torch.utils.data import ( ) import pytorch_lightning as pl -from lightning_lite.utilities import LightningEnum from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset from lightning_lite.utilities.data import has_len as new_has_len @@ -41,24 +40,13 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn -BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] +# might be supported in later releases, see https://github.com/python/mypy/pull/13297 +BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] # type: ignore[misc] warning_cache = WarningCache() -class _WrapAttrTag(LightningEnum): - SET = "set" - DEL = "del" - - def __call__(self, *args): - if self == self.SET: - fn = setattr - else: - fn = delattr - return fn(*args) - - -def _extract_batch_size(batch: BType) -> Generator[int, None, None]: +def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: yield 1 @@ -109,7 +97,7 @@ def extract_batch_size(batch: BType) -> int: def has_len_all_ranks( dataloader: DataLoader, - strategy: "pl.Strategy", + strategy: "pl.strategies.Strategy", model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or @@ -151,14 +139,14 @@ def has_len_all_ranks( return has_len -def get_len(dataloader: DataLoader) -> Union[int, float]: +def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]: """Return the length of the given DataLoader. If ``__len__`` method is not implemented, return float('inf'). """ if new_has_len(dataloader): - return len(dataloader) + return len(dataloader) # type: ignore [arg-type] return float("inf") @@ -173,7 +161,7 @@ def _update_dataloader( def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, disallow_batch_sampler: bool = False, ) -> Tuple[Tuple[Any], Dict[str, Any]]: @@ -197,7 +185,7 @@ def _get_dataloader_init_args_and_kwargs( arg_names = () # get the dataloader instance `__init__` parameters - params = dict(inspect.signature(dataloader.__init__).parameters) + params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore[misc] has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) if has_variadic_kwargs: # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)` @@ -239,14 +227,14 @@ def _get_dataloader_init_args_and_kwargs( } # the dataloader has required args which we could not extract from the existing attributes if required_args: - required_args = sorted(required_args) + sorted_required_args = sorted(required_args) dataloader_cls_name = dataloader.__class__.__name__ - missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args) + missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args) raise MisconfigurationException( f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. " "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a " - "`*_dataloader` hook of your module, we will do this for you." + f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` " + "inside a `*_dataloader` hook of your module, we will do this for you." f" Otherwise, define {missing_args_message} inside your `__init__`." ) @@ -254,13 +242,13 @@ def _get_dataloader_init_args_and_kwargs( # the dataloader signature does not allow keyword arguments that need to be passed missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys() if missing_kwargs: - missing_kwargs = sorted(missing_kwargs) + sorted_missing_kwargs = sorted(missing_kwargs) dataloader_cls_name = dataloader.__class__.__name__ raise MisconfigurationException( f"Trying to inject parameters into the `{dataloader_cls_name}` instance. " "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, " - "add the `__init__` arguments or allow passing `**kwargs`" + f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` " + "class, add the `__init__` arguments or allow passing `**kwargs`" ) if _FaultTolerantMode.detect_current_mode().is_automatic: @@ -273,7 +261,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, disallow_batch_sampler: bool = False, ) -> Dict[str, Any]: