Fix mypy in `utilities.parsing` (#8132)
This commit is contained in:
parent
9bbca402ff
commit
667def8d89
|
@ -17,8 +17,11 @@ import pickle
|
||||||
import types
|
import types
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
from pytorch_lightning.utilities.warnings import rank_zero_warn
|
from pytorch_lightning.utilities.warnings import rank_zero_warn
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,10 +56,10 @@ def str_to_bool(val: str) -> bool:
|
||||||
>>> str_to_bool('FALSE')
|
>>> str_to_bool('FALSE')
|
||||||
False
|
False
|
||||||
"""
|
"""
|
||||||
val = str_to_bool_or_str(val)
|
val_converted = str_to_bool_or_str(val)
|
||||||
if isinstance(val, bool):
|
if isinstance(val_converted, bool):
|
||||||
return val
|
return val_converted
|
||||||
raise ValueError(f'invalid truth value {val}')
|
raise ValueError(f'invalid truth value {val_converted}')
|
||||||
|
|
||||||
|
|
||||||
def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
|
def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
|
||||||
|
@ -71,13 +74,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
|
||||||
>>> str_to_bool_or_int("abc")
|
>>> str_to_bool_or_int("abc")
|
||||||
'abc'
|
'abc'
|
||||||
"""
|
"""
|
||||||
val = str_to_bool_or_str(val)
|
val_converted = str_to_bool_or_str(val)
|
||||||
if isinstance(val, bool):
|
if isinstance(val_converted, bool):
|
||||||
return val
|
return val_converted
|
||||||
try:
|
try:
|
||||||
return int(val)
|
return int(val_converted)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return val
|
return val_converted
|
||||||
|
|
||||||
|
|
||||||
def is_picklable(obj: object) -> bool:
|
def is_picklable(obj: object) -> bool:
|
||||||
|
@ -90,7 +93,7 @@ def is_picklable(obj: object) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def clean_namespace(hparams):
|
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
|
||||||
"""Removes all unpicklable entries from hparams"""
|
"""Removes all unpicklable entries from hparams"""
|
||||||
|
|
||||||
hparams_dict = hparams
|
hparams_dict = hparams
|
||||||
|
@ -104,7 +107,7 @@ def clean_namespace(hparams):
|
||||||
del hparams_dict[k]
|
del hparams_dict[k]
|
||||||
|
|
||||||
|
|
||||||
def parse_class_init_keys(cls) -> Tuple[str, str, str]:
|
def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, Optional[str], Optional[str]]:
|
||||||
"""Parse key words for standard self, *args and **kwargs
|
"""Parse key words for standard self, *args and **kwargs
|
||||||
|
|
||||||
>>> class Model():
|
>>> class Model():
|
||||||
|
@ -120,10 +123,14 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
|
||||||
# self is always first
|
# self is always first
|
||||||
n_self = init_params[0].name
|
n_self = init_params[0].name
|
||||||
|
|
||||||
def _get_first_if_any(params, param_type):
|
def _get_first_if_any(
|
||||||
|
params: List[inspect.Parameter],
|
||||||
|
param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD],
|
||||||
|
) -> Optional[str]:
|
||||||
for p in params:
|
for p in params:
|
||||||
if p.kind == param_type:
|
if p.kind == param_type:
|
||||||
return p.name
|
return p.name
|
||||||
|
return None
|
||||||
|
|
||||||
n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
|
n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
|
||||||
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)
|
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)
|
||||||
|
@ -131,7 +138,7 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
|
||||||
return n_self, n_args, n_kwargs
|
return n_self, n_args, n_kwargs
|
||||||
|
|
||||||
|
|
||||||
def get_init_args(frame) -> dict:
|
def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
|
||||||
_, _, _, local_vars = inspect.getargvalues(frame)
|
_, _, _, local_vars = inspect.getargvalues(frame)
|
||||||
if '__class__' not in local_vars:
|
if '__class__' not in local_vars:
|
||||||
return {}
|
return {}
|
||||||
|
@ -142,12 +149,18 @@ def get_init_args(frame) -> dict:
|
||||||
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
|
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
|
||||||
# only collect variables that appear in the signature
|
# only collect variables that appear in the signature
|
||||||
local_args = {k: local_vars[k] for k in init_parameters.keys()}
|
local_args = {k: local_vars[k] for k in init_parameters.keys()}
|
||||||
local_args.update(local_args.get(kwargs_var, {}))
|
# kwargs_var might be None => raised an error by mypy
|
||||||
|
if kwargs_var:
|
||||||
|
local_args.update(local_args.get(kwargs_var, {}))
|
||||||
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
|
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
|
||||||
return local_args
|
return local_args
|
||||||
|
|
||||||
|
|
||||||
def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
|
def collect_init_args(
|
||||||
|
frame: types.FrameType,
|
||||||
|
path_args: List[Dict[str, Any]],
|
||||||
|
inside: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Recursively collects the arguments passed to the child constructors in the inheritance tree.
|
Recursively collects the arguments passed to the child constructors in the inheritance tree.
|
||||||
|
|
||||||
|
@ -162,6 +175,10 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
|
||||||
most specific class in the hierarchy.
|
most specific class in the hierarchy.
|
||||||
"""
|
"""
|
||||||
_, _, _, local_vars = inspect.getargvalues(frame)
|
_, _, _, local_vars = inspect.getargvalues(frame)
|
||||||
|
# frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy
|
||||||
|
if not isinstance(frame.f_back, types.FrameType):
|
||||||
|
return path_args
|
||||||
|
|
||||||
if '__class__' in local_vars:
|
if '__class__' in local_vars:
|
||||||
local_args = get_init_args(frame)
|
local_args = get_init_args(frame)
|
||||||
# recursive update
|
# recursive update
|
||||||
|
@ -172,7 +189,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
|
||||||
return path_args
|
return path_args
|
||||||
|
|
||||||
|
|
||||||
def flatten_dict(source, result=None):
|
def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
if result is None:
|
if result is None:
|
||||||
result = {}
|
result = {}
|
||||||
|
|
||||||
|
@ -187,7 +204,7 @@ def flatten_dict(source, result=None):
|
||||||
|
|
||||||
def save_hyperparameters(
|
def save_hyperparameters(
|
||||||
obj: Any,
|
obj: Any,
|
||||||
*args,
|
*args: Any,
|
||||||
ignore: Optional[Union[Sequence[str], str]] = None,
|
ignore: Optional[Union[Sequence[str], str]] = None,
|
||||||
frame: Optional[types.FrameType] = None
|
frame: Optional[types.FrameType] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
@ -198,7 +215,12 @@ def save_hyperparameters(
|
||||||
return
|
return
|
||||||
|
|
||||||
if not frame:
|
if not frame:
|
||||||
frame = inspect.currentframe().f_back
|
current_frame = inspect.currentframe()
|
||||||
|
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
|
||||||
|
if current_frame:
|
||||||
|
frame = current_frame.f_back
|
||||||
|
if not isinstance(frame, types.FrameType):
|
||||||
|
raise AttributeError("There is no `frame` available while being required.")
|
||||||
|
|
||||||
if is_dataclass(obj):
|
if is_dataclass(obj):
|
||||||
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
|
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
|
||||||
|
@ -251,16 +273,16 @@ class AttributeDict(Dict):
|
||||||
"my-key": 3.14
|
"my-key": 3.14
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __getattr__(self, key):
|
def __getattr__(self, key: str) -> Optional[Any]:
|
||||||
try:
|
try:
|
||||||
return self[key]
|
return self[key]
|
||||||
except KeyError as exp:
|
except KeyError as exp:
|
||||||
raise AttributeError(f'Missing attribute "{key}"') from exp
|
raise AttributeError(f'Missing attribute "{key}"') from exp
|
||||||
|
|
||||||
def __setattr__(self, key, val):
|
def __setattr__(self, key: str, val: Any) -> None:
|
||||||
self[key] = val
|
self[key] = val
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
if not len(self):
|
if not len(self):
|
||||||
return ""
|
return ""
|
||||||
max_key_length = max([len(str(k)) for k in self])
|
max_key_length = max([len(str(k)) for k in self])
|
||||||
|
@ -270,14 +292,14 @@ class AttributeDict(Dict):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _lightning_get_all_attr_holders(model, attribute):
|
def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
|
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
|
||||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
|
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
|
||||||
"""
|
"""
|
||||||
trainer = getattr(model, 'trainer', None)
|
trainer = getattr(model, 'trainer', None)
|
||||||
|
|
||||||
holders = []
|
holders: List[Any] = []
|
||||||
|
|
||||||
# Check if attribute in model
|
# Check if attribute in model
|
||||||
if hasattr(model, attribute):
|
if hasattr(model, attribute):
|
||||||
|
@ -295,7 +317,7 @@ def _lightning_get_all_attr_holders(model, attribute):
|
||||||
return holders
|
return holders
|
||||||
|
|
||||||
|
|
||||||
def _lightning_get_first_attr_holder(model, attribute):
|
def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
|
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
|
||||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
|
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
|
||||||
|
@ -308,7 +330,7 @@ def _lightning_get_first_attr_holder(model, attribute):
|
||||||
return holders[-1]
|
return holders[-1]
|
||||||
|
|
||||||
|
|
||||||
def lightning_hasattr(model, attribute):
|
def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Special hasattr for Lightning. Checks for attribute in model namespace,
|
Special hasattr for Lightning. Checks for attribute in model namespace,
|
||||||
the old hparams namespace/dict, and the datamodule.
|
the old hparams namespace/dict, and the datamodule.
|
||||||
|
@ -316,7 +338,7 @@ def lightning_hasattr(model, attribute):
|
||||||
return _lightning_get_first_attr_holder(model, attribute) is not None
|
return _lightning_get_first_attr_holder(model, attribute) is not None
|
||||||
|
|
||||||
|
|
||||||
def lightning_getattr(model, attribute):
|
def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
|
||||||
"""
|
"""
|
||||||
Special getattr for Lightning. Checks for attribute in model namespace,
|
Special getattr for Lightning. Checks for attribute in model namespace,
|
||||||
the old hparams namespace/dict, and the datamodule.
|
the old hparams namespace/dict, and the datamodule.
|
||||||
|
@ -338,7 +360,7 @@ def lightning_getattr(model, attribute):
|
||||||
return getattr(holder, attribute)
|
return getattr(holder, attribute)
|
||||||
|
|
||||||
|
|
||||||
def lightning_setattr(model, attribute, value):
|
def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Special setattr for Lightning. Checks for attribute in model namespace
|
Special setattr for Lightning. Checks for attribute in model namespace
|
||||||
and the old hparams namespace/dict.
|
and the old hparams namespace/dict.
|
||||||
|
|
|
@ -185,6 +185,8 @@ ignore_errors = True
|
||||||
ignore_errors = False
|
ignore_errors = False
|
||||||
[mypy-pytorch_lightning.utilities.device_parser]
|
[mypy-pytorch_lightning.utilities.device_parser]
|
||||||
ignore_errors = False
|
ignore_errors = False
|
||||||
|
[mypy-pytorch_lightning.utilities.parsing]
|
||||||
|
ignore_errors = False
|
||||||
|
|
||||||
# todo: add proper typing to this module...
|
# todo: add proper typing to this module...
|
||||||
[mypy-pl_examples.*]
|
[mypy-pl_examples.*]
|
||||||
|
|
Loading…
Reference in New Issue