diff --git a/src/lightning/pytorch/loops/fetchers.py b/src/lightning/pytorch/loops/fetchers.py index 954b1ab574..e699321a4d 100644 --- a/src/lightning/pytorch/loops/fetchers.py +++ b/src/lightning/pytorch/loops/fetchers.py @@ -14,6 +14,8 @@ from typing import Any, Iterator, List, Optional +from typing_extensions import override + from lightning.fabric.utilities.data import sized_len from lightning.pytorch.utilities.combined_loader import _ITERATOR_RETURN, CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException @@ -44,11 +46,13 @@ class _DataFetcher(Iterator): def setup(self, combined_loader: CombinedLoader) -> None: self._combined_loader = combined_loader + @override def __iter__(self) -> "_DataFetcher": self.iterator = iter(self.combined_loader) self.reset() return self + @override def __next__(self) -> _ITERATOR_RETURN: assert self.iterator is not None self._start_profiler() @@ -95,6 +99,7 @@ class _PrefetchDataFetcher(_DataFetcher): self.prefetch_batches = prefetch_batches self.batches: List[Any] = [] + @override def __iter__(self) -> "_PrefetchDataFetcher": super().__iter__() if self.length is not None: @@ -111,6 +116,7 @@ class _PrefetchDataFetcher(_DataFetcher): break return self + @override def __next__(self) -> _ITERATOR_RETURN: if self.batches: # there are pre-fetched batches already from a previous `prefetching` call. @@ -130,6 +136,7 @@ class _PrefetchDataFetcher(_DataFetcher): raise StopIteration return batch + @override def reset(self) -> None: super().reset() self.batches = [] @@ -156,16 +163,19 @@ class _DataLoaderIterDataFetcher(_DataFetcher): self._batch_idx: int = 0 self._dataloader_idx: int = 0 + @override def __iter__(self) -> "_DataLoaderIterDataFetcher": super().__iter__() self.iterator_wrapper = iter(_DataFetcherWrapper(self)) return self + @override def __next__(self) -> Iterator["_DataFetcherWrapper"]: # type: ignore[override] if self.done: raise StopIteration return self.iterator_wrapper + @override def reset(self) -> None: super().reset() self._batch = None @@ -189,6 +199,7 @@ class _DataFetcherWrapper(Iterator): def length(self) -> Optional[int]: return self.data_fetcher.length + @override def __next__(self) -> _ITERATOR_RETURN: fetcher = self.data_fetcher if fetcher.done: diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 5a5165f78a..db9345f1ab 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -15,6 +15,7 @@ import logging from typing import Optional, Union import torch +from typing_extensions import override import lightning.pytorch as pl from lightning.fabric.utilities.data import _set_sampler_epoch, sized_len @@ -117,6 +118,7 @@ class _FitLoop(_Loop): return self.epoch_loop.max_steps @_Loop.restarting.setter + @override def restarting(self, restarting: bool) -> None: # if the last epoch completely finished, we are not actually restarting values = self.epoch_progress.current.ready, self.epoch_progress.current.started diff --git a/src/lightning/pytorch/loops/optimization/automatic.py b/src/lightning/pytorch/loops/optimization/automatic.py index 4648bf53e3..450e8a0c60 100644 --- a/src/lightning/pytorch/loops/optimization/automatic.py +++ b/src/lightning/pytorch/loops/optimization/automatic.py @@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict import torch from torch import Tensor from torch.optim import Optimizer +from typing_extensions import override import lightning.pytorch as pl from lightning.pytorch.loops.loop import _Loop @@ -81,6 +82,7 @@ class ClosureResult(OutputResult): return cls(closure_loss, extra=extra) + @override def asdict(self) -> Dict[str, Any]: return {"loss": self.loss, **self.extra} @@ -121,6 +123,7 @@ class Closure(AbstractClosure[ClosureResult]): self._backward_fn = backward_fn self._zero_grad_fn = zero_grad_fn + @override @torch.enable_grad() def closure(self, *args: Any, **kwargs: Any) -> ClosureResult: step_output = self._step_fn() @@ -136,6 +139,7 @@ class Closure(AbstractClosure[ClosureResult]): return step_output + @override def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]: self._result = self.closure(*args, **kwargs) return self._result.loss diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index 31e5569ab7..d8a4f1968c 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -17,6 +17,7 @@ from dataclasses import dataclass, field from typing import Any, Dict from torch import Tensor +from typing_extensions import override import lightning.pytorch as pl from lightning.pytorch.core.optimizer import do_nothing_closure @@ -59,6 +60,7 @@ class ManualResult(OutputResult): return cls(extra=extra) + @override def asdict(self) -> Dict[str, Any]: return self.extra diff --git a/src/lightning/pytorch/loops/progress.py b/src/lightning/pytorch/loops/progress.py index 8ff12ba378..3d34653122 100644 --- a/src/lightning/pytorch/loops/progress.py +++ b/src/lightning/pytorch/loops/progress.py @@ -14,6 +14,8 @@ from dataclasses import asdict, dataclass, field from typing import Type +from typing_extensions import override + @dataclass class _BaseProgress: @@ -51,6 +53,7 @@ class _ReadyCompletedTracker(_BaseProgress): ready: int = 0 completed: int = 0 + @override def reset(self) -> None: """Reset the state.""" self.ready = 0 @@ -81,10 +84,12 @@ class _StartedTracker(_ReadyCompletedTracker): started: int = 0 + @override def reset(self) -> None: super().reset() self.started = 0 + @override def reset_on_restart(self) -> None: super().reset_on_restart() self.started = self.completed @@ -106,10 +111,12 @@ class _ProcessedTracker(_StartedTracker): processed: int = 0 + @override def reset(self) -> None: super().reset() self.processed = 0 + @override def reset_on_restart(self) -> None: super().reset_on_restart() self.processed = self.completed @@ -157,6 +164,7 @@ class _Progress(_BaseProgress): """Utility function to easily create an instance from keyword arguments to both ``Tracker``s.""" return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs)) + @override def reset(self) -> None: self.total.reset() self.current.reset() @@ -167,6 +175,7 @@ class _Progress(_BaseProgress): def reset_on_restart(self) -> None: self.current.reset_on_restart() + @override def load_state_dict(self, state_dict: dict) -> None: self.total.load_state_dict(state_dict["total"]) self.current.load_state_dict(state_dict["current"]) @@ -187,14 +196,17 @@ class _BatchProgress(_Progress): is_last_batch: bool = False + @override def reset(self) -> None: super().reset() self.is_last_batch = False + @override def reset_on_run(self) -> None: super().reset_on_run() self.is_last_batch = False + @override def load_state_dict(self, state_dict: dict) -> None: super().load_state_dict(state_dict) self.is_last_batch = state_dict["is_last_batch"] @@ -229,6 +241,7 @@ class _OptimizerProgress(_BaseProgress): step: _Progress = field(default_factory=lambda: _Progress.from_defaults(_ReadyCompletedTracker)) zero_grad: _Progress = field(default_factory=lambda: _Progress.from_defaults(_StartedTracker)) + @override def reset(self) -> None: self.step.reset() self.zero_grad.reset() @@ -241,6 +254,7 @@ class _OptimizerProgress(_BaseProgress): self.step.reset_on_restart() self.zero_grad.reset_on_restart() + @override def load_state_dict(self, state_dict: dict) -> None: self.step.load_state_dict(state_dict["step"]) self.zero_grad.load_state_dict(state_dict["zero_grad"]) @@ -261,6 +275,7 @@ class _OptimizationProgress(_BaseProgress): def optimizer_steps(self) -> int: return self.optimizer.step.total.completed + @override def reset(self) -> None: self.optimizer.reset() @@ -270,5 +285,6 @@ class _OptimizationProgress(_BaseProgress): def reset_on_restart(self) -> None: self.optimizer.reset_on_restart() + @override def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 8b925d8f83..ad98c65300 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -15,6 +15,8 @@ import math from collections import OrderedDict from typing import Any, Dict, Optional, Union +from typing_extensions import override + import lightning.pytorch as pl from lightning.pytorch import loops # import as loops to avoid circular imports from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher @@ -304,11 +306,13 @@ class _TrainingEpochLoop(loops._Loop): self._results.cpu() self.val_loop.teardown() + @override def on_save_checkpoint(self) -> Dict: state_dict = super().on_save_checkpoint() state_dict["_batches_that_stepped"] = self._batches_that_stepped return state_dict + @override def on_load_checkpoint(self, state_dict: Dict) -> None: self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)