fixes typing errors in auto_restart.py (#13904)
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent e5395de9d3
commit 381600dcc3
@@ -56,7 +56,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
 ]
 ignore_errors = "True"
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from collections.abc import Sized
 from copy import deepcopy
 from dataclasses import dataclass, field
 from functools import partial, wraps
@@ -24,6 +25,7 @@ from torch.utils.data.dataloader import (
     DataLoader,
     IterableDataset,
 )
+from typing_extensions import TypedDict
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -34,6 +36,21 @@ from pytorch_lightning.utilities.seed import _collect_rng_states, _set_rng_state
 from pytorch_lightning.utilities.types import _Stateful
 
 
+class _IteratorStateDict(TypedDict):
+    dataset_state: Dict[int, Any]
+    sampler_state: Dict[int, Any]
+    worker_id: int
+    num_workers: int
+    num_batches_fetched: int
+    name: Optional[str]
+
+
+class _MergedIteratorStateDict(TypedDict):
+    state: Dict[str, Any]
+    latest_worker_id: int
+    represent_map_dataset: Optional[bool]
+
+
 class FastForwardSampler(Sampler):
     """This FastForwardSampler wraps a :class:`torch.utils.data.Sampler` and records the number of iterations
     performed during an epoch.
@@ -45,7 +62,7 @@ class FastForwardSampler(Sampler):
     samples seen in the last iterations (for the current worker).
     """
 
-    def __init__(self, sampler: Union[Sampler, Generator], attr_name: Optional[str] = None) -> None:
+    def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None:
         super().__init__(data_source=None)
         self._sampler = sampler
         self.restarting: bool = False
@@ -79,7 +96,7 @@ class FastForwardSampler(Sampler):
         self._counter = 0
         return self
 
-    def __next__(self):
+    def __next__(self) -> Any:
         # the `state dict` was cached as workers were unavailable before.
         if self._cached_state_dict is not None:
             self._load_non_random_state(self._cached_state_dict)
@@ -109,6 +126,7 @@ class FastForwardSampler(Sampler):
         raise StopIteration
 
     def __len__(self) -> int:
+        assert isinstance(self._sampler, Sized)
         return len(self._sampler)
 
     def state_dict(self, num_batches_processed: Optional[int] = None) -> Dict[int, Dict[str, int]]:
@@ -161,7 +179,7 @@ class IteratorState:
     name: Optional[str] = None
 
     @classmethod
-    def from_state_dict(cls, state_dict) -> "IteratorState":
+    def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
         return cls(**state_dict)
 
 
@@ -173,22 +191,22 @@ class MergedIteratorState:
     worker states in this merged iterator state.
     """
 
-    state: Union[Dict[Union[int, str], Union[Dict[str, IteratorState], IteratorState]]] = field(default_factory=dict)
+    state: Dict = field(default_factory=dict)
     latest_worker_id: int = 0
    represent_map_dataset: Optional[bool] = None
 
     def update(self, generator_name: Optional[str], new_state: IteratorState) -> None:
         # a map based dataset doesn't own a generator and therefore `generator_name` should be None.
         self.represent_map_dataset = generator_name is None
-        if self.represent_map_dataset:
-            state = self.state
+        latest_worker_id = new_state.worker_id
+        if generator_name is None:
+            self.state[latest_worker_id] = new_state
         else:
             if generator_name not in self.state:
                 self.state[generator_name] = {}
             state = self.state[generator_name]
+            state[latest_worker_id] = new_state
 
-        latest_worker_id = new_state.worker_id
-        state[latest_worker_id] = new_state
         self.latest_worker_id = latest_worker_id
 
     @property
@@ -202,7 +220,7 @@ class MergedIteratorState:
         return {k: self.state[k].dataset_state[k] for k in self.state.keys()}
 
     @classmethod
-    def from_state_dict(cls, state_dict) -> "MergedIteratorState":
+    def from_state_dict(cls, state_dict: _MergedIteratorStateDict) -> "MergedIteratorState":
         if state_dict["represent_map_dataset"]:
             state_dict["state"] = {
                 worker_id: IteratorState.from_state_dict(state) for worker_id, state in state_dict["state"].items()
@@ -229,15 +247,15 @@ class CaptureMapDataset(Dataset):
     """
 
     def __init__(self, dataset: Dataset) -> None:
-        self.dataset = dataset
-        self._cached_state_dict = None
+        self.dataset: Dataset = dataset
+        self._cached_state_dict: Optional[Dict[int, Any]] = None
 
     @property
     def worker_id(self) -> int:
         worker_info = get_worker_info()
         return worker_info.id if worker_info else 0
 
-    def __getitem__(self, item) -> Tuple[Any, Dict[int, Dict]]:
+    def __getitem__(self, item: int) -> Tuple[Any, Dict[int, Dict]]:
         if self._cached_state_dict is not None:
             if self.worker_id in self._cached_state_dict:
                 _set_rng_states(self._cached_state_dict[self.worker_id]["rng_states"])
@@ -246,6 +264,7 @@ class CaptureMapDataset(Dataset):
         return self.dataset[item]
 
     def __len__(self) -> int:
+        assert isinstance(self.dataset, Sized)
         return len(self.dataset)
 
     def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
@@ -268,7 +287,7 @@ class CaptureIterableDataset(IterableDataset):
         super().__init__()
         self.dataset = deepcopy(dataset)
         self.samplers: Optional[Dict[str, FastForwardSampler]] = None
-        self._state_dict: Optional[Dict[int, Any]] = None
+        self._state_dict: Optional[Dict[str, Any]] = None
         self._has_wrapped: bool = False
 
     @property
@@ -276,9 +295,10 @@ class CaptureIterableDataset(IterableDataset):
         return self.dataset.sampler
 
     def state_dict(self) -> Dict[str, Any]:
+        assert self.samplers is not None
         return {k: v.state_dict() for k, v in self.samplers.items()}
 
-    def load_state_dict(self, state_dict: Dict[int, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         self._state_dict = deepcopy(state_dict)
 
     def _wrap_generator_samplers(self) -> None:
@@ -311,7 +331,7 @@ class CaptureIterableDataset(IterableDataset):
 
         self.reset_on_epoch()
 
-    def reset_on_epoch(self):
+    def reset_on_epoch(self) -> None:
         self._state_dict = None
 
     def __iter__(self) -> Iterator:
@@ -371,8 +391,8 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
     for _ in range(state_dict["previous_worker"] - 1):
         next(iter_dataloader._worker_queue_idx_cycle)
 
-    # we can finally call reset and apply prefecthing.
-    iter_dataloader._reset = iter_dataloader._original_reset
+    # we can finally call reset and apply prefetching.
+    iter_dataloader._reset = iter_dataloader._original_reset  # type: ignore[assignment]
     iter_dataloader._reset(dataloader, first_iter=True)
     # return the iterator
     return iter_dataloader
@@ -445,6 +465,7 @@ def _next_data_wrapper(
             ]
         elif isinstance(dataset, CaptureMapDataset):
             ff_sampler = _find_fast_forward_samplers(dl)
+            assert ff_sampler is not None
             state = [
                 IteratorState(
                     num_workers=dl.num_workers,
@@ -519,6 +540,7 @@ def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader,
 
     # reload sampler state
     ff_sampler = _find_fast_forward_samplers(dataloader)
+    assert ff_sampler is not None
     ff_sampler.load_state_dict(iterator_state.sampler_state)
 
     # reload dataset state
@@ -610,18 +632,20 @@ def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_wor
     return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
 
 
-class _StatefulDataLoaderIter:
+class _StatefulDataLoaderIter(_BaseDataLoaderIter):
     """This mixin is used to make PyTorch DataLoaderIter stateful."""
 
-    def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
+    def __accumulate_state(self, sampler_state: Dict[int, Any]) -> None:
         # store sampler state within a queue alongside its idx.
-        self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
+        self._sampler_state_idx: int = getattr(self, "_sampler_state_idx", 0) + 1
         self._sampler_state.append((sampler_state, self._sampler_state_idx))
 
     def _store_sampler_state(self) -> None:
         """This function is used to extract the sampler states if any."""
-        sampler_state = {
-            k: v.state_dict() for k, v in self._loader.__dict__.items() if isinstance(v, _Stateful) and k != "dataset"
+        sampler_state: Dict[int, Any] = {
+            k: v.state_dict()  # type: ignore[misc]
+            for k, v in self._loader.__dict__.items()
+            if isinstance(v, _Stateful) and k != "dataset"
         }
         self.__accumulate_state(sampler_state)
 
@@ -630,12 +654,12 @@ class _StatefulDataLoaderIter:
         self._store_sampler_state()
         return indexes
 
-    def _prepare_loader(self, loader):
+    def _prepare_loader(self, loader: DataLoader) -> None:
         _add_capture_metadata_collate(loader)
         self._loader = loader
         self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
         self.num_batches_fetched = 0
-        self._sampler_state = []
+        self._sampler_state: List[Tuple[Dict[int, Any], int]] = []
         self._sampler_state_idx = 0
 
     def __del__(self) -> None:
@@ -680,7 +704,7 @@ class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProc
         super().__init__(loader)
 
 
-def _get_iterator(self) -> "_BaseDataLoaderIter":
+def _get_iterator(self: DataLoader) -> "_BaseDataLoaderIter":
     if not hasattr(self, "_lightning_fetcher"):
         raise MisconfigurationException(
             "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
@@ -699,7 +723,7 @@ def _patch_dataloader_get_iterators() -> None:
         return
     if not hasattr(DataLoader, "_ori_get_iterator"):
         DataLoader._ori_get_iterator = DataLoader._get_iterator
-    DataLoader._get_iterator = _get_iterator
+    DataLoader._get_iterator = _get_iterator  # type: ignore[assignment]
 
 
 def _teardown_dataloader_get_iterators() -> None:
@@ -707,7 +731,7 @@ def _teardown_dataloader_get_iterators() -> None:
     # cleanup the get_iterator replacement in case of Fault-tolerance.
     get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
     if get_iterator:
-        DataLoader._get_iterator = get_iterator
+        DataLoader._get_iterator = get_iterator  # type: ignore[assignment]
         del DataLoader._ori_get_iterator
 
 
@@ -781,16 +805,17 @@ def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.
         raise ValueError("Fault-tolerance supports only a single dataloader.")
 
     for dataloader in dl_loaders:
+        assert isinstance(dataloader, DataLoader)
         validator_fn = (
             _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
         )
         validator_fn(dataloader)
 
 
-def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
+def _collect_states_on_rank_zero_over_collection(state_dict: Dict, key: str = "state") -> Dict:
     """This utility collects the state across processes for a collection of state."""
 
-    def fn(state: Dict):
+    def fn(state: Dict) -> Dict:
         if key in state:
             return _collect_states_on_rank_zero(state)
         return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}
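For context, and not part of the commit itself: below is a minimal, self-contained sketch of the TypedDict-plus-dataclass pattern the hunks above introduce for `IteratorState.from_state_dict` in `pytorch_lightning.utilities.auto_restart`. The class names mirror the diff, but the definitions here are illustrative stand-ins rather than the library's own, and the sketch imports `TypedDict` from `typing` (Python 3.8+) instead of `typing_extensions` to stay dependency-free.

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, TypedDict


class _IteratorStateDict(TypedDict):
    # Keys and value types of the serialized iterator state, visible to the type checker.
    dataset_state: Dict[int, Any]
    sampler_state: Dict[int, Any]
    worker_id: int
    num_workers: int
    num_batches_fetched: int
    name: Optional[str]


@dataclass
class IteratorState:
    # Illustrative stand-in mirroring the dataclass of the same name in the diff.
    dataset_state: Dict[int, Any] = field(default_factory=dict)
    sampler_state: Dict[int, Any] = field(default_factory=dict)
    worker_id: int = 0
    num_workers: int = 0
    num_batches_fetched: int = 0
    name: Optional[str] = None

    @classmethod
    def from_state_dict(cls, state_dict: _IteratorStateDict) -> "IteratorState":
        # Annotating `state_dict` as a TypedDict rather than a bare dict lets a
        # type checker check the `**` unpacking against the dataclass fields.
        return cls(**state_dict)


# Hypothetical payload, shaped like _IteratorStateDict, to show the round trip.
payload: _IteratorStateDict = {
    "dataset_state": {0: {"rng_states": {}}},
    "sampler_state": {0: {"current_iteration": 3}},
    "worker_id": 0,
    "num_workers": 2,
    "num_batches_fetched": 3,
    "name": None,
}
print(IteratorState.from_state_dict(payload))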
@@ -20,7 +20,7 @@ from argparse import _ArgumentGroup, ArgumentParser
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -90,21 +90,24 @@ class PredictStep(Protocol):
         ...
 
 
+_DictKey = TypeVar("_DictKey")
+
+
 @runtime_checkable
-class _Stateful(Protocol):
+class _Stateful(Protocol[_DictKey]):
     """This class is used to detect if an object is stateful using `isinstance(obj, _Stateful)`."""
 
-    def state_dict(self) -> Dict[str, Any]:
+    def state_dict(self) -> Dict[_DictKey, Any]:
         ...
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
         ...
 
 
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class _LRScheduler(_Stateful, Protocol):
+class _LRScheduler(_Stateful[str], Protocol):
     optimizer: Optimizer
     base_lrs: List[float]
 
@@ -118,7 +121,7 @@ class _LRScheduler(_Stateful, Protocol):
 # Inferred from `torch.optim.lr_scheduler.pyi`
 # Missing attributes were added to improve typing
 @runtime_checkable
-class ReduceLROnPlateau(_Stateful, Protocol):
+class ReduceLROnPlateau(_Stateful[str], Protocol):
     in_cooldown: bool
     optimizer: Optimizer
 
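Likewise for context, not part of the commit: a small sketch of how the now-generic `_Stateful` protocol from the last hunks above (the protocol that auto_restart.py imports from `pytorch_lightning.utilities.types`) is meant to be used. The protocol is re-declared so the snippet is self-contained; `Counter` and `collect_state` are hypothetical, and the `isinstance` check uses the unparameterized protocol because subscripted generics such as `_Stateful[str]` cannot be passed to `isinstance`.

from typing import Any, Dict, List, Protocol, TypeVar, runtime_checkable

_DictKey = TypeVar("_DictKey")


@runtime_checkable
class _Stateful(Protocol[_DictKey]):
    # Illustrative re-declaration of the protocol shown in the diff above.
    def state_dict(self) -> Dict[_DictKey, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
        ...


class Counter:
    # Hypothetical class: satisfies `_Stateful[str]` structurally, no inheritance needed.
    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"count": self.count}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.count = state_dict["count"]


def collect_state(candidates: List[Any]) -> List[Dict[str, Any]]:
    # The runtime check only verifies that both methods exist; the `[str]`
    # key parameter matters only to static type checkers such as mypy.
    return [obj.state_dict() for obj in candidates if isinstance(obj, _Stateful)]


print(collect_state([Counter(), object()]))  # -> [{'count': 0}]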