Add `@override` for files in `src/lightning/pytorch/tuner` (#19005)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
This commit is contained in:
Victor Prins 2023-11-18 14:58:07 +01:00 committed by GitHub
parent de7faf976b
commit 8c1e9e499d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
import torch
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import override
import lightning.pytorch as pl
from lightning.fabric.utilities.types import _TORCH_LRSCHEDULER
@ -379,6 +380,7 @@ class _LRCallback(Callback):
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar = None
@override
def on_train_batch_start(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
) -> None:
@ -391,6 +393,7 @@ class _LRCallback(Callback):
self.lrs.append(trainer.lr_scheduler_configs[0].scheduler.lr[0]) # type: ignore[union-attr]
@override
def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
@ -456,6 +459,8 @@ class _LinearLR(_TORCH_LRSCHEDULER):
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)
# mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
@override # type: ignore[misc]
def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
@ -492,6 +497,8 @@ class _ExponentialLR(_TORCH_LRSCHEDULER):
self.num_iter = num_iter
super().__init__(optimizer, last_epoch)
# mypy can't follow the _TORCH_LRSCHEDULER TypeAlias, so ignore "no base method" error
@override # type: ignore[misc]
def get_lr(self) -> List[float]:
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter