Change the return type of `tune()` in `trainer.py` to TypedDict (#13631)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
donlapark 2022-07-14 21:08:10 +07:00 committed by GitHub
parent eec862ef2f
commit 8f56f8e0cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 6 deletions

View File

@ -83,7 +83,6 @@ module = [
"pytorch_lightning.trainer.supporters",
"pytorch_lightning.trainer.trainer",
"pytorch_lightning.tuner.batch_size_scaling",
"pytorch_lightning.tuner.tuning",
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
"pytorch_lightning.utilities.distributed",

View File

@ -84,8 +84,7 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
from pytorch_lightning.utilities import (
_HPU_AVAILABLE,
_IPU_AVAILABLE,
@ -1014,7 +1013,7 @@ class Trainer(
datamodule: Optional[LightningDataModule] = None,
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Union[int, _LRFinder]]]:
) -> _TunerResult:
r"""
Runs routines to tune hyperparameters before training.

View File

@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional, Union
from typing_extensions import NotRequired, TypedDict
import pytorch_lightning as pl
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.tuner.batch_size_scaling import scale_batch_size
@ -21,6 +23,11 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
class _TunerResult(TypedDict):
lr_find: NotRequired[Optional[_LRFinder]]
scale_batch_size: NotRequired[Optional[int]]
class Tuner:
"""Tuner class to tune your model."""
@ -36,11 +43,11 @@ class Tuner:
model: "pl.LightningModule",
scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
lr_find_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Optional[Union[int, _LRFinder]]]:
) -> _TunerResult:
scale_batch_size_kwargs = scale_batch_size_kwargs or {}
lr_find_kwargs = lr_find_kwargs or {}
# return a dict instead of a tuple so BC is not broken if a new tuning procedure is added
result = {}
result = _TunerResult()
self.trainer.strategy.connect(model)