Add `@override` for files in `src/lightning/pytorch/loops` (#18966)
This commit is contained in:
parent
c524c0b01b
commit
2334a8acba
|
@ -14,6 +14,8 @@
|
|||
|
||||
from typing import Any, Iterator, List, Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from lightning.fabric.utilities.data import sized_len
|
||||
from lightning.pytorch.utilities.combined_loader import _ITERATOR_RETURN, CombinedLoader
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
|
@ -44,11 +46,13 @@ class _DataFetcher(Iterator):
|
|||
def setup(self, combined_loader: CombinedLoader) -> None:
|
||||
self._combined_loader = combined_loader
|
||||
|
||||
@override
|
||||
def __iter__(self) -> "_DataFetcher":
|
||||
self.iterator = iter(self.combined_loader)
|
||||
self.reset()
|
||||
return self
|
||||
|
||||
@override
|
||||
def __next__(self) -> _ITERATOR_RETURN:
|
||||
assert self.iterator is not None
|
||||
self._start_profiler()
|
||||
|
@ -95,6 +99,7 @@ class _PrefetchDataFetcher(_DataFetcher):
|
|||
self.prefetch_batches = prefetch_batches
|
||||
self.batches: List[Any] = []
|
||||
|
||||
@override
|
||||
def __iter__(self) -> "_PrefetchDataFetcher":
|
||||
super().__iter__()
|
||||
if self.length is not None:
|
||||
|
@ -111,6 +116,7 @@ class _PrefetchDataFetcher(_DataFetcher):
|
|||
break
|
||||
return self
|
||||
|
||||
@override
|
||||
def __next__(self) -> _ITERATOR_RETURN:
|
||||
if self.batches:
|
||||
# there are pre-fetched batches already from a previous `prefetching` call.
|
||||
|
@ -130,6 +136,7 @@ class _PrefetchDataFetcher(_DataFetcher):
|
|||
raise StopIteration
|
||||
return batch
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
self.batches = []
|
||||
|
@ -156,16 +163,19 @@ class _DataLoaderIterDataFetcher(_DataFetcher):
|
|||
self._batch_idx: int = 0
|
||||
self._dataloader_idx: int = 0
|
||||
|
||||
@override
|
||||
def __iter__(self) -> "_DataLoaderIterDataFetcher":
|
||||
super().__iter__()
|
||||
self.iterator_wrapper = iter(_DataFetcherWrapper(self))
|
||||
return self
|
||||
|
||||
@override
|
||||
def __next__(self) -> Iterator["_DataFetcherWrapper"]: # type: ignore[override]
|
||||
if self.done:
|
||||
raise StopIteration
|
||||
return self.iterator_wrapper
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
self._batch = None
|
||||
|
@ -189,6 +199,7 @@ class _DataFetcherWrapper(Iterator):
|
|||
def length(self) -> Optional[int]:
|
||||
return self.data_fetcher.length
|
||||
|
||||
@override
|
||||
def __next__(self) -> _ITERATOR_RETURN:
|
||||
fetcher = self.data_fetcher
|
||||
if fetcher.done:
|
||||
|
|
|
@ -15,6 +15,7 @@ import logging
|
|||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.data import _set_sampler_epoch, sized_len
|
||||
|
@ -117,6 +118,7 @@ class _FitLoop(_Loop):
|
|||
return self.epoch_loop.max_steps
|
||||
|
||||
@_Loop.restarting.setter
|
||||
@override
|
||||
def restarting(self, restarting: bool) -> None:
|
||||
# if the last epoch completely finished, we are not actually restarting
|
||||
values = self.epoch_progress.current.ready, self.epoch_progress.current.started
|
||||
|
|
|
@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict
|
|||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch.loops.loop import _Loop
|
||||
|
@ -81,6 +82,7 @@ class ClosureResult(OutputResult):
|
|||
|
||||
return cls(closure_loss, extra=extra)
|
||||
|
||||
@override
|
||||
def asdict(self) -> Dict[str, Any]:
|
||||
return {"loss": self.loss, **self.extra}
|
||||
|
||||
|
@ -121,6 +123,7 @@ class Closure(AbstractClosure[ClosureResult]):
|
|||
self._backward_fn = backward_fn
|
||||
self._zero_grad_fn = zero_grad_fn
|
||||
|
||||
@override
|
||||
@torch.enable_grad()
|
||||
def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
|
||||
step_output = self._step_fn()
|
||||
|
@ -136,6 +139,7 @@ class Closure(AbstractClosure[ClosureResult]):
|
|||
|
||||
return step_output
|
||||
|
||||
@override
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
|
||||
self._result = self.closure(*args, **kwargs)
|
||||
return self._result.loss
|
||||
|
|
|
@ -17,6 +17,7 @@ from dataclasses import dataclass, field
|
|||
from typing import Any, Dict
|
||||
|
||||
from torch import Tensor
|
||||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch.core.optimizer import do_nothing_closure
|
||||
|
@ -59,6 +60,7 @@ class ManualResult(OutputResult):
|
|||
|
||||
return cls(extra=extra)
|
||||
|
||||
@override
|
||||
def asdict(self) -> Dict[str, Any]:
|
||||
return self.extra
|
||||
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Type
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BaseProgress:
|
||||
|
@ -51,6 +53,7 @@ class _ReadyCompletedTracker(_BaseProgress):
|
|||
ready: int = 0
|
||||
completed: int = 0
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
"""Reset the state."""
|
||||
self.ready = 0
|
||||
|
@ -81,10 +84,12 @@ class _StartedTracker(_ReadyCompletedTracker):
|
|||
|
||||
started: int = 0
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
self.started = 0
|
||||
|
||||
@override
|
||||
def reset_on_restart(self) -> None:
|
||||
super().reset_on_restart()
|
||||
self.started = self.completed
|
||||
|
@ -106,10 +111,12 @@ class _ProcessedTracker(_StartedTracker):
|
|||
|
||||
processed: int = 0
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
self.processed = 0
|
||||
|
||||
@override
|
||||
def reset_on_restart(self) -> None:
|
||||
super().reset_on_restart()
|
||||
self.processed = self.completed
|
||||
|
@ -157,6 +164,7 @@ class _Progress(_BaseProgress):
|
|||
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
|
||||
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self.total.reset()
|
||||
self.current.reset()
|
||||
|
@ -167,6 +175,7 @@ class _Progress(_BaseProgress):
|
|||
def reset_on_restart(self) -> None:
|
||||
self.current.reset_on_restart()
|
||||
|
||||
@override
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.total.load_state_dict(state_dict["total"])
|
||||
self.current.load_state_dict(state_dict["current"])
|
||||
|
@ -187,14 +196,17 @@ class _BatchProgress(_Progress):
|
|||
|
||||
is_last_batch: bool = False
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
super().reset()
|
||||
self.is_last_batch = False
|
||||
|
||||
@override
|
||||
def reset_on_run(self) -> None:
|
||||
super().reset_on_run()
|
||||
self.is_last_batch = False
|
||||
|
||||
@override
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
super().load_state_dict(state_dict)
|
||||
self.is_last_batch = state_dict["is_last_batch"]
|
||||
|
@ -229,6 +241,7 @@ class _OptimizerProgress(_BaseProgress):
|
|||
step: _Progress = field(default_factory=lambda: _Progress.from_defaults(_ReadyCompletedTracker))
|
||||
zero_grad: _Progress = field(default_factory=lambda: _Progress.from_defaults(_StartedTracker))
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self.step.reset()
|
||||
self.zero_grad.reset()
|
||||
|
@ -241,6 +254,7 @@ class _OptimizerProgress(_BaseProgress):
|
|||
self.step.reset_on_restart()
|
||||
self.zero_grad.reset_on_restart()
|
||||
|
||||
@override
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.step.load_state_dict(state_dict["step"])
|
||||
self.zero_grad.load_state_dict(state_dict["zero_grad"])
|
||||
|
@ -261,6 +275,7 @@ class _OptimizationProgress(_BaseProgress):
|
|||
def optimizer_steps(self) -> int:
|
||||
return self.optimizer.step.total.completed
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self.optimizer.reset()
|
||||
|
||||
|
@ -270,5 +285,6 @@ class _OptimizationProgress(_BaseProgress):
|
|||
def reset_on_restart(self) -> None:
|
||||
self.optimizer.reset_on_restart()
|
||||
|
||||
@override
|
||||
def load_state_dict(self, state_dict: dict) -> None:
|
||||
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||
|
|
|
@ -15,6 +15,8 @@ import math
|
|||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.pytorch import loops # import as loops to avoid circular imports
|
||||
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
|
||||
|
@ -304,11 +306,13 @@ class _TrainingEpochLoop(loops._Loop):
|
|||
self._results.cpu()
|
||||
self.val_loop.teardown()
|
||||
|
||||
@override
|
||||
def on_save_checkpoint(self) -> Dict:
|
||||
state_dict = super().on_save_checkpoint()
|
||||
state_dict["_batches_that_stepped"] = self._batches_that_stepped
|
||||
return state_dict
|
||||
|
||||
@override
|
||||
def on_load_checkpoint(self, state_dict: Dict) -> None:
|
||||
self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue