Add `@override` for files in `src/lightning/pytorch/tuner` (#19005)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
parent
de7faf976b
commit
8c1e9e499d
|
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
|||
|
||||
import torch
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
from typing_extensions import override
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
|
||||
|
@ -379,6 +380,7 @@ class _LRCallback(Callback):
|
|||
self.progress_bar_refresh_rate = progress_bar_refresh_rate
|
||||
self.progress_bar = None
|
||||
|
||||
@override
|
||||
def on_train_batch_start(
|
||||
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
|
||||
) -> None:
|
||||
|
@ -391,6 +393,7 @@ class _LRCallback(Callback):
|
|||
|
||||
self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) # type: ignore[union-attr]
|
||||
|
||||
@override
|
||||
def on_train_batch_end(
|
||||
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
|
||||
) -> None:
|
||||
|
@ -456,6 +459,8 @@ class _LinearLR(_TORCH_LRSCHEDULER):
|
|||
self.num_iter = num_iter
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
# mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
|
||||
@override # type: ignore[misc]
|
||||
def get_lr(self) -> List[float]:
|
||||
curr_iter = self.last_epoch + 1
|
||||
r = curr_iter / self.num_iter
|
||||
|
@ -492,6 +497,8 @@ class _ExponentialLR(_TORCH_LRSCHEDULER):
|
|||
self.num_iter = num_iter
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
# mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
|
||||
@override # type: ignore[misc]
|
||||
def get_lr(self) -> List[float]:
|
||||
curr_iter = self.last_epoch + 1
|
||||
r = curr_iter / self.num_iter
|
||||
|
|
Loading…
Reference in New Issue