Change the return type of `tune()` in `trainer.py` to TypedDict (#13631)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

parent eec862ef2f
commit 8f56f8e0cd
pyproject.toml

@@ -83,7 +83,6 @@ module = [
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
-    "pytorch_lightning.tuner.tuning",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
     "pytorch_lightning.utilities.distributed",
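This list appears to be the mypy per-module override exempting modules from type checking; dropping `pytorch_lightning.tuner.tuning` from it means the tuning module is type-checked from this commit on, which the `_TunerResult` TypedDict introduced below makes possible.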
pytorch_lightning/trainer/trainer.py

@@ -84,8 +84,7 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.trainer.supporters import CombinedLoader
-from pytorch_lightning.tuner.lr_finder import _LRFinder
-from pytorch_lightning.tuner.tuning import Tuner
+from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
 from pytorch_lightning.utilities import (
     _HPU_AVAILABLE,
     _IPU_AVAILABLE,
@@ -1014,7 +1013,7 @@ class Trainer(
         datamodule: Optional[LightningDataModule] = None,
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
+    ) -> _TunerResult:
         r"""
         Runs routines to tune hyperparameters before training.
 
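Note that only the static annotation changes here: a TypedDict is a plain dict at runtime, so existing `trainer.tune(...)` callers receive exactly the same object and keys as before.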
pytorch_lightning/tuner/tuning.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 from typing import Any, Dict, Optional, Union
 
+from typing_extensions import NotRequired, TypedDict
+
 import pytorch_lightning as pl
 from pytorch_lightning.trainer.states import TrainerStatus
 from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
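`NotRequired` and `TypedDict` are imported from `typing_extensions` rather than `typing` because `NotRequired` comes from PEP 655 and only reached the standard library in Python 3.11; `typing_extensions` backports it to the older interpreters the project still supports.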
@@ -21,6 +23,11 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 
 
+class _TunerResult(TypedDict):
+    lr_find: NotRequired[Optional[_LRFinder]]
+    scale_batch_size: NotRequired[Optional[int]]
+
+
 class Tuner:
     """Tuner class to tune your model."""
 
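To make the semantics of the new return type concrete, here is a minimal, self-contained sketch (not part of the commit; `_DemoResult` and `run_scaling_only` are illustrative names, and `str` stands in for `_LRFinder`). `NotRequired` marks each key as one that may be absent, which fits `tune()` because each tuning routine only fills in its own entry:

from typing import Optional

from typing_extensions import NotRequired, TypedDict


class _DemoResult(TypedDict):
    # Mirrors _TunerResult; str stands in for _LRFinder here.
    lr_find: NotRequired[Optional[str]]
    scale_batch_size: NotRequired[Optional[int]]


def run_scaling_only() -> _DemoResult:
    result = _DemoResult()  # at runtime a TypedDict call is just dict()
    result["scale_batch_size"] = 32  # mypy checks the key name and value type
    return result


res = run_scaling_only()
# NotRequired keys may be absent entirely, so check membership before indexing:
print(res.get("lr_find", "lr_find was not run"))  # -> lr_find was not run
print(res["scale_batch_size"])  # -> 32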
@@ -36,11 +43,11 @@ class Tuner:
         model: "pl.LightningModule",
         scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
         lr_find_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
+    ) -> _TunerResult:
         scale_batch_size_kwargs = scale_batch_size_kwargs or {}
         lr_find_kwargs = lr_find_kwargs or {}
         # return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
-        result = {}
+        result = _TunerResult()
 
         self.trainer.strategy.connect(model)
 
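The "return a dict instead of a tuple" comment is the design rationale this commit preserves: tuple results force positional unpacking at call sites, so adding a new tuning routine would change the shape and break them, while a dict grows by name. A hedged sketch of that argument, with illustrative names (`tune_as_tuple`, `tune_as_dict`, `_DemoTunerResult`) and `str` again standing in for `_LRFinder`:

from typing import Optional, Tuple

from typing_extensions import NotRequired, TypedDict


def tune_as_tuple() -> Tuple[Optional[int], Optional[str]]:
    # Tuple style: adding a third routine changes the shape and breaks
    # every existing `batch_size, lr = tune_as_tuple()` call site.
    return 16, None


class _DemoTunerResult(TypedDict):
    scale_batch_size: NotRequired[Optional[int]]
    lr_find: NotRequired[Optional[str]]  # str stands in for _LRFinder


def tune_as_dict() -> _DemoTunerResult:
    # Dict style: callers look up keys by name, so keys added later are
    # invisible to old call sites instead of breaking them.
    return {"scale_batch_size": 16}


batch_size, lr = tune_as_tuple()
print(batch_size, lr)  # -> 16 None
print(tune_as_dict().get("scale_batch_size"))  # -> 16, regardless of future keys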