Fix typing in `pl.core.mixins.hparams_mixin` (#10800)
* fix typing in hparams mixin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unused import Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
97e52619ea
commit
bd3fb2e66e
|
@ -60,7 +60,6 @@ module = [
|
||||||
"pytorch_lightning.core.decorators",
|
"pytorch_lightning.core.decorators",
|
||||||
"pytorch_lightning.core.lightning",
|
"pytorch_lightning.core.lightning",
|
||||||
"pytorch_lightning.core.mixins.device_dtype_mixin",
|
"pytorch_lightning.core.mixins.device_dtype_mixin",
|
||||||
"pytorch_lightning.core.mixins.hparams_mixin",
|
|
||||||
"pytorch_lightning.core.saving",
|
"pytorch_lightning.core.saving",
|
||||||
"pytorch_lightning.distributed.dist",
|
"pytorch_lightning.distributed.dist",
|
||||||
"pytorch_lightning.loggers.base",
|
"pytorch_lightning.loggers.base",
|
||||||
|
|
|
@ -15,7 +15,7 @@ import copy
|
||||||
import inspect
|
import inspect
|
||||||
import types
|
import types
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from typing import MutableMapping, Optional, Sequence, Union
|
from typing import Any, MutableMapping, Optional, Sequence, Union
|
||||||
|
|
||||||
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
|
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES
|
||||||
from pytorch_lightning.utilities import AttributeDict
|
from pytorch_lightning.utilities import AttributeDict
|
||||||
|
@ -32,7 +32,7 @@ class HyperparametersMixin:
|
||||||
|
|
||||||
def save_hyperparameters(
|
def save_hyperparameters(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args: Any,
|
||||||
ignore: Optional[Union[Sequence[str], str]] = None,
|
ignore: Optional[Union[Sequence[str], str]] = None,
|
||||||
frame: Optional[types.FrameType] = None,
|
frame: Optional[types.FrameType] = None,
|
||||||
logger: bool = True,
|
logger: bool = True,
|
||||||
|
@ -101,7 +101,9 @@ class HyperparametersMixin:
|
||||||
self._log_hyperparams = logger
|
self._log_hyperparams = logger
|
||||||
# the frame needs to be created in this file.
|
# the frame needs to be created in this file.
|
||||||
if not frame:
|
if not frame:
|
||||||
frame = inspect.currentframe().f_back
|
current_frame = inspect.currentframe()
|
||||||
|
if current_frame:
|
||||||
|
frame = current_frame.f_back
|
||||||
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
|
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
|
||||||
|
|
||||||
def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
|
def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
|
||||||
|
@ -113,7 +115,7 @@ class HyperparametersMixin:
|
||||||
self._hparams = hp
|
self._hparams = hp
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]):
|
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
|
||||||
if isinstance(hp, Namespace):
|
if isinstance(hp, Namespace):
|
||||||
hp = vars(hp)
|
hp = vars(hp)
|
||||||
if isinstance(hp, dict):
|
if isinstance(hp, dict):
|
||||||
|
@ -125,12 +127,12 @@ class HyperparametersMixin:
|
||||||
return hp
|
return hp
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hparams(self) -> Union[AttributeDict, dict, Namespace]:
|
def hparams(self) -> Union[AttributeDict, MutableMapping]:
|
||||||
"""The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
|
"""The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user.
|
||||||
For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
|
For the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Union[AttributeDict, dict, Namespace]: mutable hyperparameters dicionary
|
Mutable hyperparameters dicionary
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, "_hparams"):
|
if not hasattr(self, "_hparams"):
|
||||||
self._hparams = AttributeDict()
|
self._hparams = AttributeDict()
|
||||||
|
|
Loading…
Reference in New Issue