Fix typing in `pl.core.mixins.hparams_mixin` (#10800)

* fix typing in hparams mixin

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused import

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Adrian Wälchli 2021-11-29 12:03:47 +01:00 committed by GitHub
parent 97e52619ea
commit bd3fb2e66e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 7 deletions

View File

@ -60,7 +60,6 @@ module = [
"pytorch_lightning.core.decorators",
"pytorch_lightning.core.lightning",
"pytorch_lightning.core.mixins.device_dtype_mixin",
"pytorch_lightning.core.mixins.hparams_mixin",
"pytorch_lightning.core.saving",
"pytorch_lightning.distributed.dist",
"pytorch_lightning.loggers.base",

View File

@ -15,7 +15,7 @@ import copy
import inspect
import types
from argparse import Namespace
from typing import MutableMapping, Optional, Sequence, Union
from typing import Any, MutableMapping, Optional, Sequence, Union
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
from pytorch_lightning.utilities import AttributeDict
@ -32,7 +32,7 @@ class HyperparametersMixin:
def save_hyperparameters(
self,
*args,
*args: Any,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None,
logger: bool = True,
@ -101,7 +101,9 @@ class HyperparametersMixin:
self._log_hyperparams = logger
# the frame needs to be created in this file.
if not frame:
frame = inspect.currentframe().f_back
current_frame = inspect.currentframe()
if current_frame:
frame = current_frame.f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
@ -113,7 +115,7 @@ class HyperparametersMixin:
self._hparams = hp
@staticmethod
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]):
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
if isinstance(hp, Namespace):
hp = vars(hp)
if isinstance(hp, dict):
@ -125,12 +127,12 @@ class HyperparametersMixin:
return hp
@property
def hparams(self) -> Union[AttributeDict, dict, Namespace]:
def hparams(self) -> Union[AttributeDict, MutableMapping]:
"""The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
Returns:
Union[AttributeDict, dict, Namespace]: mutable hyperparameters dicionary
Mutable hyperparameters dicionary
"""
if not hasattr(self, "_hparams"):
self._hparams = AttributeDict()