diff --git a/pyproject.toml b/pyproject.toml
index 45e9eaafc4..e886795e23 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,7 +83,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.tuner.tuning",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
     "pytorch_lightning.utilities.distributed",
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index d2c4135026..02cb951e1e 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -84,8 +84,7 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.tuner.lr_finder import _LRFinder
-from pytorch_lightning.tuner.tuning import Tuner
+from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
 from pytorch_lightning.utilities import (
     _HPU_AVAILABLE,
     _IPU_AVAILABLE,
@@ -1014,7 +1013,7 @@ class Trainer(
         datamodule: Optional[LightningDataModule] = None,
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
+    ) -> _TunerResult:
         r"""
         Runs routines to tune hyperparameters before training.
diff --git a/src/pytorch_lightning/tuner/tuning.py b/src/pytorch_lightning/tuner/tuning.py
index b1a38bd276..79ebf3bd0d 100644
--- a/src/pytorch_lightning/tuner/tuning.py
+++ b/src/pytorch_lightning/tuner/tuning.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 from typing import Any, Dict, Optional, Union
 
+from typing_extensions import NotRequired, TypedDict
+
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.states import TrainerStatus
 from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
@@ -21,6 +23,11 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 
 
+class _TunerResult(TypedDict):
+    lr_find: NotRequired[Optional[_LRFinder]]
+    scale_batch_size: NotRequired[Optional[int]]
+
+
 class Tuner:
     """Tuner class to tune your model."""
 
@@ -36,11 +43,11 @@ class Tuner:
         model: "pl.LightningModule",
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
+    ) -> _TunerResult:
         scale_batch_size_kwargs = scale_batch_size_kwargs or {}
         lr_find_kwargs = lr_find_kwargs or {}
         # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
-        result = {}
+        result = _TunerResult()
         self.trainer.strategy.connect(model)
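
For reference, a minimal runnable sketch (not part of the diff) of how the NotRequired keys on the new _TunerResult TypedDict behave at runtime. It assumes only that typing_extensions is installed and swaps a plain object placeholder in for _LRFinder so the snippet has no pytorch_lightning dependency.

    # Sketch only: mirrors the TypedDict added in tuning.py, with `object`
    # standing in for _LRFinder (an assumption for self-containment).
    from typing import Optional

    from typing_extensions import NotRequired, TypedDict


    class _TunerResult(TypedDict):
        lr_find: NotRequired[Optional[object]]  # placeholder for _LRFinder
        scale_batch_size: NotRequired[Optional[int]]


    result = _TunerResult()  # valid: both keys are NotRequired, so it may start empty
    result["scale_batch_size"] = 64  # only the routine that actually ran adds its key

    print("scale_batch_size" in result)  # True
    print(result.get("lr_find"))         # None -- key was never set
    print(result)                         # {'scale_batch_size': 64}

Marking both fields NotRequired lets the tuner populate only the entries for the routines that actually ran, while callers of Trainer.tune get a precise return type instead of Dict[str, Optional[Union[int, _LRFinder]]].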