Fix typing in `pl.core.mixins.hparams_mixin` (#10800)
* fix typing in hparams mixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unused import Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
97e52619ea
commit
bd3fb2e66e
|
@ -60,7 +60,6 @@ module = [
|
|||
"pytorch_lightning.core.decorators",
|
||||
"pytorch_lightning.core.lightning",
|
||||
"pytorch_lightning.core.mixins.device_dtype_mixin",
|
||||
"pytorch_lightning.core.mixins.hparams_mixin",
|
||||
"pytorch_lightning.core.saving",
|
||||
"pytorch_lightning.distributed.dist",
|
||||
"pytorch_lightning.loggers.base",
|
||||
|
|
|
@ -15,7 +15,7 @@ import copy
|
|||
import inspect
|
||||
import types
|
||||
from argparse import Namespace
|
||||
from typing import MutableMapping, Optional, Sequence, Union
|
||||
from typing import Any, MutableMapping, Optional, Sequence, Union
|
||||
|
||||
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
|
||||
from pytorch_lightning.utilities import AttributeDict
|
||||
|
@ -32,7 +32,7 @@ class HyperparametersMixin:
|
|||
|
||||
def save_hyperparameters(
|
||||
self,
|
||||
*args,
|
||||
*args: Any,
|
||||
ignore: Optional[Union[Sequence[str], str]] = None,
|
||||
frame: Optional[types.FrameType] = None,
|
||||
logger: bool = True,
|
||||
|
@ -101,7 +101,9 @@ class HyperparametersMixin:
|
|||
self._log_hyperparams = logger
|
||||
# the frame needs to be created in this file.
|
||||
if not frame:
|
||||
frame = inspect.currentframe().f_back
|
||||
current_frame = inspect.currentframe()
|
||||
if current_frame:
|
||||
frame = current_frame.f_back
|
||||
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
|
||||
|
||||
def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
|
||||
|
@ -113,7 +115,7 @@ class HyperparametersMixin:
|
|||
self._hparams = hp
|
||||
|
||||
@staticmethod
|
||||
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]):
|
||||
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
|
||||
if isinstance(hp, Namespace):
|
||||
hp = vars(hp)
|
||||
if isinstance(hp, dict):
|
||||
|
@ -125,12 +127,12 @@ class HyperparametersMixin:
|
|||
return hp
|
||||
|
||||
@property
|
||||
def hparams(self) -> Union[AttributeDict, dict, Namespace]:
|
||||
def hparams(self) -> Union[AttributeDict, MutableMapping]:
|
||||
"""The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
|
||||
For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
|
||||
|
||||
Returns:
|
||||
Union[AttributeDict, dict, Namespace]: mutable hyperparameters dicionary
|
||||
Mutable hyperparameters dicionary
|
||||
"""
|
||||
if not hasattr(self, "_hparams"):
|
||||
self._hparams = AttributeDict()
|
||||
|
|
Loading…
Reference in New Issue