diff --git a/pyproject.toml b/pyproject.toml
index df89de3d09..5b62baf9ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
 ]
 ignore_errors = "True"
diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py
index 3877a1ab39..e90dcc7172 100644
--- a/src/pytorch_lightning/utilities/auto_restart.py
+++ b/src/pytorch_lightning/utilities/auto_restart.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Sized
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import partial, wraps
@@ -24,6 +25,7 @@ from torch.utils.data.dataloader import (
     DataLoader,
     IterableDataset,
 )
+from typing_extensions import TypedDict
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -34,6 +36,21 @@ from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_state
 from pytorch_lightning.utilities.types import _Stateful
 
 
+class _IteratorStateDict(TypedDict):
+    dataset_state: Dict[int, Any]
+    sampler_state: Dict[int, Any]
+    worker_id: int
+    num_workers: int
+    num_batches_fetched: int
+    name: Optional[str]
+
+
+class _MergedIteratorStateDict(TypedDict):
+    state: Dict[str, Any]
+    latest_worker_id: int
+    represent_map_dataset: Optional[bool]
+
+
 class FastForwardSampler(Sampler):
     """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations
     performed during an epoch.
@@ -45,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """
 
-    def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
@@ -79,7 +96,7 @@ class FastForwardSampler(Sampler):
         self._counter = 0
         return self
 
-    def __next__(self):
+    def __next__(self) -> Any:
         # the `state dict` was cached as workers were unavailable before.
         if self._cached_state_dict is not None:
             self._load_non_random_state(self._cached_state_dict)
@@ -109,6 +126,7 @@ class FastForwardSampler(Sampler):
         raise StopIteration
 
     def __len__(self) -> int:
+        assert isinstance(self._sampler, Sized)
         return len(self._sampler)
 
     def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
@@ -161,7 +179,7 @@ class IteratorState:
     name: Optional[str] = None
 
     @classmethod
-    def from_state_dict(cls, state_dict) -> "IteratorState":
+    def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
         return cls(**state_dict)
 
 
@@ -173,22 +191,22 @@ class MergedIteratorState:
     worker states in this merged iterator state.
     """
 
-    state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
+    state: Dict = field(default_factory=dict)
     latest_worker_id: int = 0
     represent_map_dataset: Optional[bool] = None
 
     def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
         # a map based dataset doesn't own a generator and therefore `generator_name` should be None.
         self.represent_map_dataset = generator_name is None
-        if self.represent_map_dataset:
-            state = self.state
+        latest_worker_id = new_state.worker_id
+        if generator_name is None:
+            self.state[latest_worker_id] = new_state
         else:
             if generator_name not in self.state:
                 self.state[generator_name] = {}
             state = self.state[generator_name]
+            state[latest_worker_id] = new_state
 
-        latest_worker_id = new_state.worker_id
-        state[latest_worker_id] = new_state
         self.latest_worker_id = latest_worker_id
 
     @property
@@ -202,7 +220,7 @@ class MergedIteratorState:
         return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
 
     @classmethod
-    def from_state_dict(cls, state_dict) -> "MergedIteratorState":
+    def from_state_dict(cls, state_dict: _MergedIteratorStateDict) -> "MergedIteratorState":
         if state_dict["represent_map_dataset"]:
             state_dict["state"] = {
                 worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
             }
@@ -229,15 +247,15 @@ class CaptureMapDataset(Dataset):
     """
 
     def __init__(self, dataset: Dataset) -> None:
-        self.dataset = dataset
-        self._cached_state_dict = None
+        self.dataset: Dataset = dataset
+        self._cached_state_dict: Optional[Dict[int, Any]] = None
 
     @property
     def worker_id(self) -> int:
         worker_info = get_worker_info()
         return worker_info.id if worker_info else 0
 
-    def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
+    def __getitem__(self, item: int) -> Tuple[Any, Dict[int, Dict]]:
         if self._cached_state_dict is not None:
             if self.worker_id in self._cached_state_dict:
                 _set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
@@ -246,6 +264,7 @@ class CaptureMapDataset(Dataset):
         return self.dataset[item]
 
     def __len__(self) -> int:
+        assert isinstance(self.dataset, Sized)
         return len(self.dataset)
 
     def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
@@ -268,7 +287,7 @@ class CaptureIterableDataset(IterableDataset):
         super().__init__()
         self.dataset = deepcopy(dataset)
         self.samplers: Optional[Dict[str, FastForwardSampler]] = None
-        self._state_dict: Optional[Dict[int, Any]] = None
+        self._state_dict: Optional[Dict[str, Any]] = None
         self._has_wrapped: bool = False
 
     @property
@@ -276,9 +295,10 @@ class CaptureIterableDataset(IterableDataset):
         return self.dataset.sampler
 
     def state_dict(self) -> Dict[str, Any]:
+        assert self.samplers is not None
         return {k: v.state_dict() for k, v in self.samplers.items()}
 
-    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self._state_dict = deepcopy(state_dict)
 
     def _wrap_generator_samplers(self) -> None:
@@ -311,7 +331,7 @@ class CaptureIterableDataset(IterableDataset):
 
         self.reset_on_epoch()
 
-    def reset_on_epoch(self):
+    def reset_on_epoch(self) -> None:
         self._state_dict = None
 
     def __iter__(self) -> Iterator:
@@ -371,8 +391,8 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
     for _ in range(state_dict["previous_worker"] - 1):
         next(iter_dataloader._worker_queue_idx_cycle)
 
-    # we can finally call reset and apply prefecthing.
-    iter_dataloader._reset = iter_dataloader._original_reset
+    # we can finally call reset and apply prefetching.
+    iter_dataloader._reset = iter_dataloader._original_reset  # type: ignore[assignment]
     iter_dataloader._reset(dataloader, first_iter=True)
     # return the iterator
     return iter_dataloader
@@ -445,6 +465,7 @@ def _next_data_wrapper(
             ]
         elif isinstance(dataset, CaptureMapDataset):
             ff_sampler = _find_fast_forward_samplers(dl)
+            assert ff_sampler is not None
             state = [
                 IteratorState(
                     num_workers=dl.num_workers,
@@ -519,6 +540,7 @@ def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader,
 
     # reload sampler state
     ff_sampler = _find_fast_forward_samplers(dataloader)
+    assert ff_sampler is not None
     ff_sampler.load_state_dict(iterator_state.sampler_state)
 
     # reload dataset state
@@ -610,18 +632,20 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
     return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
 
 
-class _StatefulDataLoaderIter:
+class _StatefulDataLoaderIter(_BaseDataLoaderIter):
     """This mixin is used to make PyTorch DataLoaderIter stateful."""
 
-    def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
+    def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None:
         # store sampler state within a queue alongside its idx.
-        self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
+        self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1
         self._sampler_state.append((sampler_state, self._sampler_state_idx))
 
     def _store_sampler_state(self) -> None:
         """This function is used to extract the sampler states if any."""
-        sampler_state = {
-            k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
+        sampler_state: Dict[int, Any] = {
+            k: v.state_dict()  # type: ignore[misc]
+            for k, v in self._loader.__dict__.items()
+            if isinstance(v, _Stateful) and k != "dataset"
         }
         self.__accumulate_state(sampler_state)
 
@@ -630,12 +654,12 @@ class _StatefulDataLoaderIter:
         self._store_sampler_state()
         return indexes
 
-    def _prepare_loader(self, loader):
+    def _prepare_loader(self, loader: DataLoader) -> None:
         _add_capture_metadata_collate(loader)
         self._loader = loader
         self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
         self.num_batches_fetched = 0
-        self._sampler_state = []
+        self._sampler_state: List[Tuple[Dict[int, Any], int]] = []
         self._sampler_state_idx = 0
 
     def __del__(self) -> None:
@@ -680,7 +704,7 @@ class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProc
         super().__init__(loader)
 
 
-def _get_iterator(self) -> "_BaseDataLoaderIter":
+def _get_iterator(self: DataLoader) -> "_BaseDataLoaderIter":
     if not hasattr(self, "_lightning_fetcher"):
         raise MisconfigurationException(
             "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
@@ -699,7 +723,7 @@ def _patch_dataloader_get_iterators() -> None:
         return
     if not hasattr(DataLoader, "_ori_get_iterator"):
         DataLoader._ori_get_iterator = DataLoader._get_iterator
-    DataLoader._get_iterator = _get_iterator
+    DataLoader._get_iterator = _get_iterator  # type: ignore[assignment]
 
 
 def _teardown_dataloader_get_iterators() -> None:
@@ -707,7 +731,7 @@
     # cleanup the get_iterator replacement in case of Fault-tolerance.
     get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
     if get_iterator:
-        DataLoader._get_iterator = get_iterator
+        DataLoader._get_iterator = get_iterator  # type: ignore[assignment]
         del DataLoader._ori_get_iterator
 
 
@@ -781,16 +805,17 @@ def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.
         raise ValueError("Fault-tolerance supports only a single dataloader.")
 
     for dataloader in dl_loaders:
+        assert isinstance(dataloader, DataLoader)
         validator_fn = (
             _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
         )
         validator_fn(dataloader)
 
 
-def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
+def _collect_states_on_rank_zero_over_collection(state_dict: Dict, key: str = "state") -> Dict:
     """This utility collects the state across processes for a collection of state."""
 
-    def fn(state: Dict):
+    def fn(state: Dict) -> Dict:
         if key in state:
             return _collect_states_on_rank_zero(state)
         return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index 7ab3d69488..c90657b34e 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -20,7 +20,7 @@ from argparse import _ArgumentGroup, ArgumentParser
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -90,21 +90,24 @@ class PredictStep(Protocol):
         ...
 
 
+_DictKey = TypeVar("_DictKey")
+
+
 @runtime_checkable
-class _Stateful(Protocol):
+class _Stateful(Protocol[_DictKey]):
     """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> Dict[_DictKey, Any]:
         ...
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
         ...
 
 
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_Stateful, Protocol):
+class _LRScheduler(_Stateful[str], Protocol):
     optimizer: Optimizer
     base_lrs: List[float]
 
@@ -118,7 +121,7 @@
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class ReduceLROnPlateau(_Stateful, Protocol):
+class ReduceLROnPlateau(_Stateful[str], Protocol):
     in_cooldown: bool
     optimizer: Optimizer
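
Illustrative note (not part of the patch): the `types.py` hunks above turn `_Stateful` into a generic, runtime-checkable Protocol parameterized by `_DictKey`, which is what lets `_store_sampler_state` keep filtering loader attributes with a plain `isinstance(v, _Stateful)` while `state_dict()` keys stay precisely typed (`_Stateful[str]`, `_Stateful[int]`, ...). Below is a minimal, self-contained sketch of that pattern under those assumptions; the `_CounterSampler` class is hypothetical and exists only for the example.

from typing import Any, Dict, TypeVar

from typing_extensions import Protocol, runtime_checkable

_DictKey = TypeVar("_DictKey")


@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    """Anything exposing ``state_dict``/``load_state_dict`` counts as stateful."""

    def state_dict(self) -> Dict[_DictKey, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
        ...


class _CounterSampler:  # hypothetical class, only to exercise the protocol
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[int, Any]:
        # int keys: a static checker can treat this object as `_Stateful[int]`
        return {0: {"count": self.count}}

    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
        self.count = state_dict[0]["count"]


# isinstance only checks that the two methods exist (type parameters are erased
# at runtime), so duck-typed filtering over an object's attributes keeps working.
assert isinstance(_CounterSampler(), _Stateful)
assert not isinstance(object(), _Stateful)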