From bd3fb2e66ebe1a52f3aaeb1b64f50c212705a6c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Nov 2021 12:03:47 +0100 Subject: [PATCH] Fix typing in `pl.core.mixins.hparams_mixin` (#10800) * fix typing in hparams mixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unused import Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyproject.toml | 1 - pytorch_lightning/core/mixins/hparams_mixin.py | 14 ++++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0e56d3a3db..0315bb0373 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,6 @@ module = [ "pytorch_lightning.core.decorators", "pytorch_lightning.core.lightning", "pytorch_lightning.core.mixins.device_dtype_mixin", - "pytorch_lightning.core.mixins.hparams_mixin", "pytorch_lightning.core.saving", "pytorch_lightning.distributed.dist", "pytorch_lightning.loggers.base", diff --git a/pytorch_lightning/core/mixins/hparams_mixin.py b/pytorch_lightning/core/mixins/hparams_mixin.py index 0e722f2bdb..8a0dd34a55 100644 --- a/pytorch_lightning/core/mixins/hparams_mixin.py +++ b/pytorch_lightning/core/mixins/hparams_mixin.py @@ -15,7 +15,7 @@ import copy import inspect import types from argparse import Namespace -from typing import MutableMapping, Optional, Sequence, Union +from typing import Any, MutableMapping, Optional, Sequence, Union from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES from pytorch_lightning.utilities import AttributeDict @@ -32,7 +32,7 @@ class HyperparametersMixin: def save_hyperparameters( self, - *args, + *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None, logger: bool = True, @@ -101,7 +101,9 @@ class HyperparametersMixin: self._log_hyperparams = logger # the frame needs to be created in this file. if not frame: - frame = inspect.currentframe().f_back + current_frame = inspect.currentframe() + if current_frame: + frame = current_frame.f_back save_hyperparameters(self, *args, ignore=ignore, frame=frame) def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None: @@ -113,7 +115,7 @@ class HyperparametersMixin: self._hparams = hp @staticmethod - def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]): + def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]: if isinstance(hp, Namespace): hp = vars(hp) if isinstance(hp, dict): @@ -125,12 +127,12 @@ class HyperparametersMixin: return hp @property - def hparams(self) -> Union[AttributeDict, dict, Namespace]: + def hparams(self) -> Union[AttributeDict, MutableMapping]: """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For the frozen set of initial hyperparameters, use :attr:`hparams_initial`. Returns: - Union[AttributeDict, dict, Namespace]: mutable hyperparameters dicionary + Mutable hyperparameters dicionary """ if not hasattr(self, "_hparams"): self._hparams = AttributeDict()