diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index 754592b891..f39788b8ea 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 
 import torch
 from lightning_utilities.core.imports import RequirementCache
+from typing_extensions import override
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
@@ -379,6 +380,7 @@ class _LRCallback(Callback):
         self.progress_bar_refresh_rate = progress_bar_refresh_rate
         self.progress_bar = None
 
+    @override
     def on_train_batch_start(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
     ) -> None:
@@ -391,6 +393,7 @@ class _LRCallback(Callback):
 
         self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0])  # type: ignore[union-attr]
 
+    @override
     def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
     ) -> None:
@@ -456,6 +459,8 @@ class _LinearLR(_TORCH_LRSCHEDULER):
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
+    # mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
+    @override  # type: ignore[misc]
     def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
@@ -492,6 +497,8 @@ class _ExponentialLR(_TORCH_LRSCHEDULER):
         self.num_iter = num_iter
         super().__init__(optimizer, last_epoch)
 
+    # mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
+    @override  # type: ignore[misc]
     def get_lr(self) -> List[float]:
         curr_iter = self.last_epoch + 1
         r = curr_iter / self.num_iter
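
For context, a minimal sketch (not part of the patch; class and method names below are hypothetical) of the check this change enables: `typing_extensions.override` makes mypy verify that a decorated method actually overrides something on a base class. The error it raises for a missing base method carries the `[misc]` code, which is exactly what the patch suppresses on the `_TORCH_LRSCHEDULER` subclasses, where mypy cannot resolve the base class through the type alias.

# Sketch only: how `@override` surfaces override mismatches under mypy.
from typing_extensions import override


class Callback:
    def on_train_batch_start(self) -> None:
        ...


class GoodCallback(Callback):
    @override  # accepted: `on_train_batch_start` exists on the base class
    def on_train_batch_start(self) -> None:
        ...


class TypoCallback(Callback):
    # mypy: Method "on_train_batch_star" is marked as an override,
    # but no base method was found with this name  [misc]
    @override
    def on_train_batch_star(self) -> None:
        ...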