Fix typing in `pl.core.mixins.hparams_mixin` (#10800)

* fix typing in hparams mixin

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused import

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-11-29 12:03:47 +01:00 committed by GitHub
parent 97e52619ea
commit bd3fb2e66e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -60,7 +60,6 @@ module = [
"pytorch_lightning.core.decorators", "pytorch_lightning.core.decorators",
"pytorch_lightning.core.lightning", "pytorch_lightning.core.lightning",
"pytorch_lightning.core.mixins.device_dtype_mixin", "pytorch_lightning.core.mixins.device_dtype_mixin",
"pytorch_lightning.core.mixins.hparams_mixin",
"pytorch_lightning.core.saving", "pytorch_lightning.core.saving",
"pytorch_lightning.distributed.dist", "pytorch_lightning.distributed.dist",
"pytorch_lightning.loggers.base", "pytorch_lightning.loggers.base",

View File

@ -15,7 +15,7 @@ import copy
import inspect import inspect
import types import types
from argparse import Namespace from argparse import Namespace
from typing import MutableMapping, Optional, Sequence, Union from typing import Any, MutableMapping, Optional, Sequence, Union
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities import AttributeDict from pytorch_lightning.utilities import AttributeDict
@ -32,7 +32,7 @@ class HyperparametersMixin:
def save_hyperparameters( def save_hyperparameters(
self, self,
*args, *args: Any,
ignore: Optional[Union[Sequence[str], str]] = None, ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None, frame: Optional[types.FrameType] = None,
logger: bool = True, logger: bool = True,
@ -101,7 +101,9 @@ class HyperparametersMixin:
self._log_hyperparams = logger self._log_hyperparams = logger
# the frame needs to be created in this file. # the frame needs to be created in this file.
if not frame: if not frame:
frame = inspect.currentframe().f_back current_frame = inspect.currentframe()
if current_frame:
frame = current_frame.f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame) save_hyperparameters(self, *args, ignore=ignore, frame=frame)
def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
@ -113,7 +115,7 @@ class HyperparametersMixin:
self._hparams = hp self._hparams = hp
@staticmethod @staticmethod
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]): def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
if isinstance(hp, Namespace): if isinstance(hp, Namespace):
hp = vars(hp) hp = vars(hp)
if isinstance(hp, dict): if isinstance(hp, dict):
@ -125,12 +127,12 @@ class HyperparametersMixin:
return hp return hp
@property @property
def hparams(self) -> Union[AttributeDict, dict, Namespace]: def hparams(self) -> Union[AttributeDict, MutableMapping]:
"""The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
For the frozen set of initial hyperparameters, use :attr:`hparams_initial`. For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
Returns: Returns:
Union[AttributeDict, dict, Namespace]: mutable hyperparameters dicionary Mutable hyperparameters dicionary
""" """
if not hasattr(self, "_hparams"): if not hasattr(self, "_hparams"):
self._hparams = AttributeDict() self._hparams = AttributeDict()