Add `@override` for files in `src/lightning/pytorch/loops` (#18966)

Victor Prins 2023-11-08 19:02:34 +01:00 committed by GitHub
parent c524c0b01b
commit 2334a8acba
6 changed files with 39 additions and 0 deletions
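
For context: the decorator added throughout this commit is `typing_extensions.override` (PEP 698). It marks a method as an intentional override of a base-class method, so a PEP 698-aware type checker can flag the method if the base-class counterpart is ever renamed or removed; at runtime the decorator simply marks the function (setting `__override__ = True`) and returns it unchanged, so the change is behavior-neutral. A minimal sketch of the pattern, using hypothetical classes rather than Lightning's:

    from typing_extensions import override


    class BaseFetcher:
        def reset(self) -> None:
            ...


    class PrefetchFetcher(BaseFetcher):
        @override
        def reset(self) -> None:  # fine: BaseFetcher defines reset()
            super().reset()

        @override
        def restart(self) -> None:  # reported by a type checker: no BaseFetcher.restart() to override
            ...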

View File

@@ -14,6 +14,8 @@
 from typing import Any, Iterator, List, Optional
+from typing_extensions import override
 from lightning.fabric.utilities.data import sized_len
 from lightning.pytorch.utilities.combined_loader import _ITERATOR_RETURN, CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -44,11 +46,13 @@ class _DataFetcher(Iterator):
     def setup(self, combined_loader: CombinedLoader) -> None:
         self._combined_loader = combined_loader
+    @override
     def __iter__(self) -> "_DataFetcher":
         self.iterator = iter(self.combined_loader)
         self.reset()
         return self
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         assert self.iterator is not None
         self._start_profiler()
@@ -95,6 +99,7 @@ class _PrefetchDataFetcher(_DataFetcher):
         self.prefetch_batches = prefetch_batches
         self.batches: List[Any] = []
+    @override
     def __iter__(self) -> "_PrefetchDataFetcher":
         super().__iter__()
         if self.length is not None:
@@ -111,6 +116,7 @@ class _PrefetchDataFetcher(_DataFetcher):
                 break
         return self
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         if self.batches:
             # there are pre-fetched batches already from a previous `prefetching` call.
@@ -130,6 +136,7 @@ class _PrefetchDataFetcher(_DataFetcher):
             raise StopIteration
         return batch
+    @override
     def reset(self) -> None:
         super().reset()
         self.batches = []
@@ -156,16 +163,19 @@ class _DataLoaderIterDataFetcher(_DataFetcher):
         self._batch_idx: int = 0
         self._dataloader_idx: int = 0
+    @override
     def __iter__(self) -> "_DataLoaderIterDataFetcher":
         super().__iter__()
         self.iterator_wrapper = iter(_DataFetcherWrapper(self))
         return self
+    @override
     def __next__(self) -> Iterator["_DataFetcherWrapper"]:  # type: ignore[override]
         if self.done:
             raise StopIteration
         return self.iterator_wrapper
+    @override
     def reset(self) -> None:
         super().reset()
         self._batch = None
@@ -189,6 +199,7 @@ class _DataFetcherWrapper(Iterator):
     def length(self) -> Optional[int]:
         return self.data_fetcher.length
+    @override
     def __next__(self) -> _ITERATOR_RETURN:
         fetcher = self.data_fetcher
         if fetcher.done:
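
One hunk above also shows `@override` coexisting with a pre-existing `# type: ignore[override]` on `_DataLoaderIterDataFetcher.__next__`, whose return type intentionally differs from the base class. The two are complementary: the ignore silences the incompatible-signature error, while the decorator still guards against the base method disappearing. A rough sketch with hypothetical fetcher classes:

    from typing import Iterator

    from typing_extensions import override


    class BaseFetcher(Iterator):
        def __next__(self) -> int:
            return 0


    class WrapperFetcher(BaseFetcher):
        @override
        def __next__(self) -> Iterator[int]:  # type: ignore[override]
            # deliberately returns an iterator instead of a single item
            return iter([0])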

View File

@@ -15,6 +15,7 @@ import logging
 from typing import Optional, Union
 import torch
+from typing_extensions import override
 import lightning.pytorch as pl
 from lightning.fabric.utilities.data import _set_sampler_epoch, sized_len
@@ -117,6 +118,7 @@ class _FitLoop(_Loop):
         return self.epoch_loop.max_steps
     @_Loop.restarting.setter
+    @override
     def restarting(self, restarting: bool) -> None:
         # if the last epoch completely finished, we are not actually restarting
         values = self.epoch_progress.current.ready, self.epoch_progress.current.started
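
In the `restarting` hunk above, `@override` sits underneath the inherited property's `.setter` decorator, directly on the setter function. A generic sketch of that shape, with hypothetical `Base`/`Child` classes standing in for `_Loop`/`_FitLoop`:

    from typing_extensions import override


    class Base:
        def __init__(self) -> None:
            self._restarting = False

        @property
        def restarting(self) -> bool:
            return self._restarting

        @restarting.setter
        def restarting(self, restarting: bool) -> None:
            self._restarting = restarting


    class Child(Base):
        @Base.restarting.setter
        @override
        def restarting(self, restarting: bool) -> None:
            # reuse the base setter, then do subclass-specific bookkeeping
            Base.restarting.fset(self, restarting)

`@Base.restarting.setter` rebuilds the property with the base getter and the new setter, the same construct `@_Loop.restarting.setter` uses above.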

View File

@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, Mapping, Optional, OrderedDict
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
+from typing_extensions import override
 import lightning.pytorch as pl
 from lightning.pytorch.loops.loop import _Loop
@@ -81,6 +82,7 @@ class ClosureResult(OutputResult):
         return cls(closure_loss, extra=extra)
+    @override
     def asdict(self) -> Dict[str, Any]:
         return {"loss": self.loss, **self.extra}
@@ -121,6 +123,7 @@ class Closure(AbstractClosure[ClosureResult]):
         self._backward_fn = backward_fn
         self._zero_grad_fn = zero_grad_fn
+    @override
     @torch.enable_grad()
     def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
         step_output = self._step_fn()
@@ -136,6 +139,7 @@ class Closure(AbstractClosure[ClosureResult]):
         return step_output
+    @override
     def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
         self._result = self.closure(*args, **kwargs)
         return self._result.loss
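
The `Closure.closure` hunk above stacks `@override` on top of `@torch.enable_grad()`. A stripped-down sketch of that combination (hypothetical closure classes, not Lightning's):

    import torch
    from typing_extensions import override


    class AbstractClosure:
        def closure(self) -> torch.Tensor:
            raise NotImplementedError


    class TrainingClosure(AbstractClosure):
        @override
        @torch.enable_grad()
        def closure(self) -> torch.Tensor:
            # gradients are force-enabled here even if the caller runs under no_grad()
            return (torch.ones(1, requires_grad=True) * 2).sum()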

View File

@@ -17,6 +17,7 @@ from dataclasses import dataclass, field
 from typing import Any, Dict
 from torch import Tensor
+from typing_extensions import override
 import lightning.pytorch as pl
 from lightning.pytorch.core.optimizer import do_nothing_closure
@@ -59,6 +60,7 @@ class ManualResult(OutputResult):
         return cls(extra=extra)
+    @override
     def asdict(self) -> Dict[str, Any]:
         return self.extra

View File

@@ -14,6 +14,8 @@
 from dataclasses import asdict, dataclass, field
 from typing import Type
+from typing_extensions import override
 @dataclass
 class _BaseProgress:
@@ -51,6 +53,7 @@ class _ReadyCompletedTracker(_BaseProgress):
     ready: int = 0
     completed: int = 0
+    @override
     def reset(self) -> None:
         """Reset the state."""
         self.ready = 0
@@ -81,10 +84,12 @@ class _StartedTracker(_ReadyCompletedTracker):
     started: int = 0
+    @override
     def reset(self) -> None:
         super().reset()
         self.started = 0
+    @override
     def reset_on_restart(self) -> None:
         super().reset_on_restart()
         self.started = self.completed
@@ -106,10 +111,12 @@ class _ProcessedTracker(_StartedTracker):
     processed: int = 0
+    @override
     def reset(self) -> None:
         super().reset()
         self.processed = 0
+    @override
     def reset_on_restart(self) -> None:
         super().reset_on_restart()
         self.processed = self.completed
@@ -157,6 +164,7 @@ class _Progress(_BaseProgress):
         """Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
         return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))
+    @override
     def reset(self) -> None:
         self.total.reset()
         self.current.reset()
@@ -167,6 +175,7 @@ class _Progress(_BaseProgress):
     def reset_on_restart(self) -> None:
         self.current.reset_on_restart()
+    @override
     def load_state_dict(self, state_dict: dict) -> None:
         self.total.load_state_dict(state_dict["total"])
         self.current.load_state_dict(state_dict["current"])
@@ -187,14 +196,17 @@ class _BatchProgress(_Progress):
     is_last_batch: bool = False
+    @override
     def reset(self) -> None:
         super().reset()
         self.is_last_batch = False
+    @override
     def reset_on_run(self) -> None:
         super().reset_on_run()
         self.is_last_batch = False
+    @override
     def load_state_dict(self, state_dict: dict) -> None:
         super().load_state_dict(state_dict)
         self.is_last_batch = state_dict["is_last_batch"]
@@ -229,6 +241,7 @@ class _OptimizerProgress(_BaseProgress):
     step: _Progress = field(default_factory=lambda: _Progress.from_defaults(_ReadyCompletedTracker))
     zero_grad: _Progress = field(default_factory=lambda: _Progress.from_defaults(_StartedTracker))
+    @override
     def reset(self) -> None:
         self.step.reset()
         self.zero_grad.reset()
@@ -241,6 +254,7 @@ class _OptimizerProgress(_BaseProgress):
         self.step.reset_on_restart()
         self.zero_grad.reset_on_restart()
+    @override
     def load_state_dict(self, state_dict: dict) -> None:
         self.step.load_state_dict(state_dict["step"])
         self.zero_grad.load_state_dict(state_dict["zero_grad"])
@@ -261,6 +275,7 @@ class _OptimizationProgress(_BaseProgress):
     def optimizer_steps(self) -> int:
         return self.optimizer.step.total.completed
+    @override
     def reset(self) -> None:
         self.optimizer.reset()
@@ -270,5 +285,6 @@ class _OptimizationProgress(_BaseProgress):
     def reset_on_restart(self) -> None:
         self.optimizer.reset_on_restart()
+    @override
     def load_state_dict(self, state_dict: dict) -> None:
         self.optimizer.load_state_dict(state_dict["optimizer"])
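
Most of the decorators in this file land on small `reset`/`reset_on_restart`/`load_state_dict` overrides of dataclass trackers. A condensed sketch of that pattern with hypothetical tracker names (mirroring, not copying, the classes above): if the base `reset` were ever renamed, every subclass marked this way would be reported rather than silently keeping a dead method.

    from dataclasses import dataclass

    from typing_extensions import override


    @dataclass
    class ReadyCompletedTracker:
        ready: int = 0
        completed: int = 0

        def reset(self) -> None:
            self.ready = 0
            self.completed = 0


    @dataclass
    class StartedTracker(ReadyCompletedTracker):
        started: int = 0

        @override
        def reset(self) -> None:  # extends, and is checked to override, the parent reset()
            super().reset()
            self.started = 0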

View File

@@ -15,6 +15,8 @@ import math
 from collections import OrderedDict
 from typing import Any, Dict, Optional, Union
+from typing_extensions import override
 import lightning.pytorch as pl
 from lightning.pytorch import loops  # import as loops to avoid circular imports
 from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
@@ -304,11 +306,13 @@ class _TrainingEpochLoop(loops._Loop):
         self._results.cpu()
         self.val_loop.teardown()
+    @override
     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
         state_dict["_batches_that_stepped"] = self._batches_that_stepped
         return state_dict
+    @override
     def on_load_checkpoint(self, state_dict: Dict) -> None:
         self._batches_that_stepped = state_dict.get("_batches_that_stepped", 0)