fixes typing errors in auto_restart.py (#13904)

Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
donlapark 2022-09-06 01:09:20 +07:00 committed by GitHub
parent e5395de9d3
commit 381600dcc3
3 changed files with 63 additions and 36 deletions
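
The bulk of the change replaces loosely typed `state_dict` arguments with `TypedDict` definitions so mypy can check key access. A minimal, self-contained sketch of that pattern follows; the field names mirror `_IteratorStateDict` from the auto_restart.py diff below, while the `summarize` helper is purely illustrative and not part of the commit.

from typing import Any, Dict, Optional

from typing_extensions import TypedDict


class _IteratorStateDict(TypedDict):
    dataset_state: Dict[int, Any]
    sampler_state: Dict[int, Any]
    worker_id: int
    num_workers: int
    num_batches_fetched: int
    name: Optional[str]


def summarize(state: _IteratorStateDict) -> str:
    # mypy knows the exact value type behind each key, so no casts are needed
    return f"worker {state['worker_id']} fetched {state['num_batches_fetched']} batches"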

View File

@@ -56,7 +56,6 @@ module = [
"pytorch_lightning.trainer.supporters",
"pytorch_lightning.trainer.trainer",
"pytorch_lightning.tuner.batch_size_scaling",
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
]
ignore_errors = "True"
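
Removing `"pytorch_lightning.utilities.auto_restart"` from this override list means the module is no longer covered by `ignore_errors`, so mypy now checks it in full. One pattern that checking requires is narrowing before `len()`: `Iterator` declares no `__len__`, so the sampler asserts `Sized` first, exactly as the `FastForwardSampler.__len__` change below does. A hedged sketch of that idiom, using a hypothetical `_WrappedIterable` wrapper:

from collections.abc import Sized
from typing import Iterator


class _WrappedIterable:
    """Hypothetical wrapper mirroring the `assert isinstance(..., Sized)` narrowing in the diff."""

    def __init__(self, it: Iterator) -> None:
        self._it = it

    def __len__(self) -> int:
        # `Iterator` has no `__len__`; the assert narrows the type so mypy accepts `len()`
        assert isinstance(self._it, Sized)
        return len(self._it)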

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sized
from copy import deepcopy
from dataclasses import dataclass, field
from functools import partial, wraps
@@ -24,6 +25,7 @@ from torch.utils.data.dataloader import (
DataLoader,
IterableDataset,
)
from typing_extensions import TypedDict
import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -34,6 +36,21 @@ from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_state
from pytorch_lightning.utilities.types import _Stateful
class _IteratorStateDict(TypedDict):
dataset_state: Dict[int, Any]
sampler_state: Dict[int, Any]
worker_id: int
num_workers: int
num_batches_fetched: int
name: Optional[str]
class _MergedIteratorStateDict(TypedDict):
state: Dict[str, Any]
latest_worker_id: int
represent_map_dataset: Optional[bool]
class FastForwardSampler(Sampler):
"""This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations
performed during an epoch.
@@ -45,7 +62,7 @@ class FastForwardSampler(Sampler):
samples seen in the last iterations (for the current worker).
"""
def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
super().__init__(data_source=None)
self._sampler = sampler
self.restarting: bool = False
@@ -79,7 +96,7 @@ class FastForwardSampler(Sampler):
self._counter = 0
return self
def __next__(self):
def __next__(self) -> Any:
# the `state dict` was cached as workers were unavailable before.
if self._cached_state_dict is not None:
self._load_non_random_state(self._cached_state_dict)
@@ -109,6 +126,7 @@ class FastForwardSampler(Sampler):
raise StopIteration
def __len__(self) -> int:
assert isinstance(self._sampler, Sized)
return len(self._sampler)
def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
@@ -161,7 +179,7 @@ class IteratorState:
name: Optional[str] = None
@classmethod
def from_state_dict(cls, state_dict) -> "IteratorState":
def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
return cls(**state_dict)
@@ -173,22 +191,22 @@ class MergedIteratorState:
worker states in this merged iterator state.
"""
state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
state: Dict = field(default_factory=dict)
latest_worker_id: int = 0
represent_map_dataset: Optional[bool] = None
def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
# a map based dataset doesn't own a generator and therefore `generator_name` should be None.
self.represent_map_dataset = generator_name is None
if self.represent_map_dataset:
state = self.state
latest_worker_id = new_state.worker_id
if generator_name is None:
self.state[latest_worker_id] = new_state
else:
if generator_name not in self.state:
self.state[generator_name] = {}
state = self.state[generator_name]
state[latest_worker_id] = new_state
latest_worker_id = new_state.worker_id
state[latest_worker_id] = new_state
self.latest_worker_id = latest_worker_id
@property
@@ -202,7 +220,7 @@ class MergedIteratorState:
return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
@classmethod
def from_state_dict(cls, state_dict) -> "MergedIteratorState":
def from_state_dict(cls, state_dict: _MergedIteratorStateDict) -> "MergedIteratorState":
if state_dict["represent_map_dataset"]:
state_dict["state"] = {
worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
@@ -229,15 +247,15 @@ class CaptureMapDataset(Dataset):
"""
def __init__(self, dataset: Dataset) -> None:
self.dataset = dataset
self._cached_state_dict = None
self.dataset: Dataset = dataset
self._cached_state_dict: Optional[Dict[int, Any]] = None
@property
def worker_id(self) -> int:
worker_info = get_worker_info()
return worker_info.id if worker_info else 0
def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
def __getitem__(self, item: int) -> Tuple[Any, Dict[int, Dict]]:
if self._cached_state_dict is not None:
if self.worker_id in self._cached_state_dict:
_set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
@@ -246,6 +264,7 @@ class CaptureMapDataset(Dataset):
return self.dataset[item]
def __len__(self) -> int:
assert isinstance(self.dataset, Sized)
return len(self.dataset)
def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
@@ -268,7 +287,7 @@ class CaptureIterableDataset(IterableDataset):
super().__init__()
self.dataset = deepcopy(dataset)
self.samplers: Optional[Dict[str, FastForwardSampler]] = None
self._state_dict: Optional[Dict[int, Any]] = None
self._state_dict: Optional[Dict[str, Any]] = None
self._has_wrapped: bool = False
@property
@@ -276,9 +295,10 @@ class CaptureIterableDataset(IterableDataset):
return self.dataset.sampler
def state_dict(self) -> Dict[str, Any]:
assert self.samplers is not None
return {k: v.state_dict() for k, v in self.samplers.items()}
def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self._state_dict = deepcopy(state_dict)
def _wrap_generator_samplers(self) -> None:
@@ -311,7 +331,7 @@ class CaptureIterableDataset(IterableDataset):
self.reset_on_epoch()
def reset_on_epoch(self):
def reset_on_epoch(self) -> None:
self._state_dict = None
def __iter__(self) -> Iterator:
@@ -371,8 +391,8 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
for _ in range(state_dict["previous_worker"] - 1):
next(iter_dataloader._worker_queue_idx_cycle)
# we can finally call reset and apply prefecthing.
iter_dataloader._reset = iter_dataloader._original_reset
# we can finally call reset and apply prefetching.
iter_dataloader._reset = iter_dataloader._original_reset # type: ignore[assignment]
iter_dataloader._reset(dataloader, first_iter=True)
# return the iterator
return iter_dataloader
@@ -445,6 +465,7 @@ def _next_data_wrapper(
]
elif isinstance(dataset, CaptureMapDataset):
ff_sampler = _find_fast_forward_samplers(dl)
assert ff_sampler is not None
state = [
IteratorState(
num_workers=dl.num_workers,
@@ -519,6 +540,7 @@ def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader,
# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
assert ff_sampler is not None
ff_sampler.load_state_dict(iterator_state.sampler_state)
# reload dataset state
@@ -610,18 +632,20 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
class _StatefulDataLoaderIter:
class _StatefulDataLoaderIter(_BaseDataLoaderIter):
"""This mixin is used to make PyTorch DataLoaderIter stateful."""
def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None:
# store sampler state within a queue alongside its idx.
self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1
self._sampler_state.append((sampler_state, self._sampler_state_idx))
def _store_sampler_state(self) -> None:
"""This function is used to extract the sampler states if any."""
sampler_state = {
k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
sampler_state: Dict[int, Any] = {
k: v.state_dict() # type: ignore[misc]
for k, v in self._loader.__dict__.items()
if isinstance(v, _Stateful) and k != "dataset"
}
self.__accumulate_state(sampler_state)
@@ -630,12 +654,12 @@ class _StatefulDataLoaderIter:
self._store_sampler_state()
return indexes
def _prepare_loader(self, loader):
def _prepare_loader(self, loader: DataLoader) -> None:
_add_capture_metadata_collate(loader)
self._loader = loader
self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
self.num_batches_fetched = 0
self._sampler_state = []
self._sampler_state: List[Tuple[Dict[int, Any], int]] = []
self._sampler_state_idx = 0
def __del__(self) -> None:
@@ -680,7 +704,7 @@ class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProc
super().__init__(loader)
def _get_iterator(self) -> "_BaseDataLoaderIter":
def _get_iterator(self: DataLoader) -> "_BaseDataLoaderIter":
if not hasattr(self, "_lightning_fetcher"):
raise MisconfigurationException(
"A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
@@ -699,7 +723,7 @@ def _patch_dataloader_get_iterators() -> None:
return
if not hasattr(DataLoader, "_ori_get_iterator"):
DataLoader._ori_get_iterator = DataLoader._get_iterator
DataLoader._get_iterator = _get_iterator
DataLoader._get_iterator = _get_iterator # type: ignore[assignment]
def _teardown_dataloader_get_iterators() -> None:
@@ -707,7 +731,7 @@ def _teardown_dataloader_get_iterators() -> None:
# cleanup the get_iterator replacement in case of Fault-tolerance.
get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
if get_iterator:
DataLoader._get_iterator = get_iterator
DataLoader._get_iterator = get_iterator # type: ignore[assignment]
del DataLoader._ori_get_iterator
@@ -781,16 +805,17 @@ def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.
raise ValueError("Fault-tolerance supports only a single dataloader.")
for dataloader in dl_loaders:
assert isinstance(dataloader, DataLoader)
validator_fn = (
_validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
)
validator_fn(dataloader)
def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
def _collect_states_on_rank_zero_over_collection(state_dict: Dict, key: str = "state") -> Dict:
"""This utility collects the state across processes for a collection of state."""
def fn(state: Dict):
def fn(state: Dict) -> Dict:
if key in state:
return _collect_states_on_rank_zero(state)
return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
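
Several of the remaining `type: ignore` comments in this file guard monkeypatching, e.g. swapping `DataLoader._get_iterator`: assigning a module-level function over a bound method is flagged by mypy as an incompatible assignment, so the diff silences only that specific error code. A reduced, hedged sketch of the patch/teardown pair; the function names here are simplified stand-ins for the ones in the diff, and the replacement body is a placeholder.

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter


def _get_iterator_stateful(self: DataLoader) -> "_BaseDataLoaderIter":
    # placeholder body; the real replacement builds a stateful iterator
    return DataLoader._ori_get_iterator(self)  # type: ignore[attr-defined]


def _patch_get_iterator() -> None:
    if not hasattr(DataLoader, "_ori_get_iterator"):
        DataLoader._ori_get_iterator = DataLoader._get_iterator  # type: ignore[attr-defined]
    # a plain function does not match the method's declared type, hence the targeted ignore
    DataLoader._get_iterator = _get_iterator_stateful  # type: ignore[assignment]


def _teardown_get_iterator() -> None:
    original = getattr(DataLoader, "_ori_get_iterator", None)
    if original is not None:
        DataLoader._get_iterator = original  # type: ignore[assignment]
        del DataLoader._ori_get_iterator  # type: ignore[attr-defined]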

View File

@@ -20,7 +20,7 @@ from argparse import _ArgumentGroup, ArgumentParser
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, TypeVar, Union
import torch
from torch import Tensor
@@ -90,21 +90,24 @@ class PredictStep(Protocol):
...
_DictKey = TypeVar("_DictKey")
@runtime_checkable
class _Stateful(Protocol):
class _Stateful(Protocol[_DictKey]):
"""This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""
def state_dict(self) -> Dict[str, Any]:
def state_dict(self) -> Dict[_DictKey, Any]:
...
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
...
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class _LRScheduler(_Stateful, Protocol):
class _LRScheduler(_Stateful[str], Protocol):
optimizer: Optimizer
base_lrs: List[float]
@@ -118,7 +121,7 @@ class _LRScheduler(_Stateful, Protocol):
# Inferred from `torch.optim.lr_scheduler.pyi`
# Missing attributes were added to improve typing
@runtime_checkable
class ReduceLROnPlateau(_Stateful, Protocol):
class ReduceLROnPlateau(_Stateful[str], Protocol):
in_cooldown: bool
optimizer: Optimizer
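
`_Stateful` is now generic over the key type of its `state_dict`, which lets `_LRScheduler` and `ReduceLROnPlateau` pin the keys to `str` while `auto_restart` can still match objects whose state dicts use other key types. A small sketch of how such a generic, runtime-checkable protocol behaves; `Counter` and `restore` are purely illustrative, the stdlib `typing` module is used here for brevity, and note that `isinstance` checks must use the unparameterized protocol.

from typing import Any, Dict, Protocol, TypeVar, runtime_checkable

_DictKey = TypeVar("_DictKey")


@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    def state_dict(self) -> Dict[_DictKey, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
        ...


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


def restore(obj: "_Stateful[str]", state: Dict[str, Any]) -> None:
    # annotations can carry the key type, e.g. `_Stateful[str]` or `_Stateful[int]`
    obj.load_state_dict(state)


assert isinstance(Counter(), _Stateful)  # structural check only; the key type is not verified at runtime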