Fix mypy in `utilities.parsing` (#8132)
This commit is contained in:
parent
9bbca402ff
commit
667def8d89
|
@ -17,8 +17,11 @@ import pickle
|
|||
import types
|
||||
from argparse import Namespace
|
||||
from dataclasses import fields, is_dataclass
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities.warnings import rank_zero_warn
|
||||
|
||||
|
||||
|
@ -53,10 +56,10 @@ def str_to_bool(val: str) -> bool:
|
|||
>>> str_to_bool('FALSE')
|
||||
False
|
||||
"""
|
||||
val = str_to_bool_or_str(val)
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
raise ValueError(f'invalid truth value {val}')
|
||||
val_converted = str_to_bool_or_str(val)
|
||||
if isinstance(val_converted, bool):
|
||||
return val_converted
|
||||
raise ValueError(f'invalid truth value {val_converted}')
|
||||
|
||||
|
||||
def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
|
||||
|
@ -71,13 +74,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
|
|||
>>> str_to_bool_or_int("abc")
|
||||
'abc'
|
||||
"""
|
||||
val = str_to_bool_or_str(val)
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
val_converted = str_to_bool_or_str(val)
|
||||
if isinstance(val_converted, bool):
|
||||
return val_converted
|
||||
try:
|
||||
return int(val)
|
||||
return int(val_converted)
|
||||
except ValueError:
|
||||
return val
|
||||
return val_converted
|
||||
|
||||
|
||||
def is_picklable(obj: object) -> bool:
|
||||
|
@ -90,7 +93,7 @@ def is_picklable(obj: object) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def clean_namespace(hparams):
|
||||
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
|
||||
"""Removes all unpicklable entries from hparams"""
|
||||
|
||||
hparams_dict = hparams
|
||||
|
@ -104,7 +107,7 @@ def clean_namespace(hparams):
|
|||
del hparams_dict[k]
|
||||
|
||||
|
||||
def parse_class_init_keys(cls) -> Tuple[str, str, str]:
|
||||
def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, Optional[str], Optional[str]]:
|
||||
"""Parse key words for standard self, *args and **kwargs
|
||||
|
||||
>>> class Model():
|
||||
|
@ -120,10 +123,14 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
|
|||
# self is always first
|
||||
n_self = init_params[0].name
|
||||
|
||||
def _get_first_if_any(params, param_type):
|
||||
def _get_first_if_any(
|
||||
params: List[inspect.Parameter],
|
||||
param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD],
|
||||
) -> Optional[str]:
|
||||
for p in params:
|
||||
if p.kind == param_type:
|
||||
return p.name
|
||||
return None
|
||||
|
||||
n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
|
||||
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)
|
||||
|
@ -131,7 +138,7 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
|
|||
return n_self, n_args, n_kwargs
|
||||
|
||||
|
||||
def get_init_args(frame) -> dict:
|
||||
def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
|
||||
_, _, _, local_vars = inspect.getargvalues(frame)
|
||||
if '__class__' not in local_vars:
|
||||
return {}
|
||||
|
@ -142,12 +149,18 @@ def get_init_args(frame) -> dict:
|
|||
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
|
||||
# only collect variables that appear in the signature
|
||||
local_args = {k: local_vars[k] for k in init_parameters.keys()}
|
||||
local_args.update(local_args.get(kwargs_var, {}))
|
||||
# kwargs_var might be None => raised an error by mypy
|
||||
if kwargs_var:
|
||||
local_args.update(local_args.get(kwargs_var, {}))
|
||||
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
|
||||
return local_args
|
||||
|
||||
|
||||
def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
|
||||
def collect_init_args(
|
||||
frame: types.FrameType,
|
||||
path_args: List[Dict[str, Any]],
|
||||
inside: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Recursively collects the arguments passed to the child constructors in the inheritance tree.
|
||||
|
||||
|
@ -162,6 +175,10 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
|
|||
most specific class in the hierarchy.
|
||||
"""
|
||||
_, _, _, local_vars = inspect.getargvalues(frame)
|
||||
# frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy
|
||||
if not isinstance(frame.f_back, types.FrameType):
|
||||
return path_args
|
||||
|
||||
if '__class__' in local_vars:
|
||||
local_args = get_init_args(frame)
|
||||
# recursive update
|
||||
|
@ -172,7 +189,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
|
|||
return path_args
|
||||
|
||||
|
||||
def flatten_dict(source, result=None):
|
||||
def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
if result is None:
|
||||
result = {}
|
||||
|
||||
|
@ -187,7 +204,7 @@ def flatten_dict(source, result=None):
|
|||
|
||||
def save_hyperparameters(
|
||||
obj: Any,
|
||||
*args,
|
||||
*args: Any,
|
||||
ignore: Optional[Union[Sequence[str], str]] = None,
|
||||
frame: Optional[types.FrameType] = None
|
||||
) -> None:
|
||||
|
@ -198,7 +215,12 @@ def save_hyperparameters(
|
|||
return
|
||||
|
||||
if not frame:
|
||||
frame = inspect.currentframe().f_back
|
||||
current_frame = inspect.currentframe()
|
||||
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
|
||||
if current_frame:
|
||||
frame = current_frame.f_back
|
||||
if not isinstance(frame, types.FrameType):
|
||||
raise AttributeError("There is no `frame` available while being required.")
|
||||
|
||||
if is_dataclass(obj):
|
||||
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
|
||||
|
@ -251,16 +273,16 @@ class AttributeDict(Dict):
|
|||
"my-key": 3.14
|
||||
"""
|
||||
|
||||
def __getattr__(self, key):
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError as exp:
|
||||
raise AttributeError(f'Missing attribute "{key}"') from exp
|
||||
|
||||
def __setattr__(self, key, val):
|
||||
def __setattr__(self, key: str, val: Any) -> None:
|
||||
self[key] = val
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
if not len(self):
|
||||
return ""
|
||||
max_key_length = max([len(str(k)) for k in self])
|
||||
|
@ -270,14 +292,14 @@ class AttributeDict(Dict):
|
|||
return out
|
||||
|
||||
|
||||
def _lightning_get_all_attr_holders(model, attribute):
|
||||
def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> List[Any]:
|
||||
"""
|
||||
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
|
||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
|
||||
"""
|
||||
trainer = getattr(model, 'trainer', None)
|
||||
|
||||
holders = []
|
||||
holders: List[Any] = []
|
||||
|
||||
# Check if attribute in model
|
||||
if hasattr(model, attribute):
|
||||
|
@ -295,7 +317,7 @@ def _lightning_get_all_attr_holders(model, attribute):
|
|||
return holders
|
||||
|
||||
|
||||
def _lightning_get_first_attr_holder(model, attribute):
|
||||
def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
|
||||
"""
|
||||
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
|
||||
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
|
||||
|
@ -308,7 +330,7 @@ def _lightning_get_first_attr_holder(model, attribute):
|
|||
return holders[-1]
|
||||
|
||||
|
||||
def lightning_hasattr(model, attribute):
|
||||
def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool:
|
||||
"""
|
||||
Special hasattr for Lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule.
|
||||
|
@ -316,7 +338,7 @@ def lightning_hasattr(model, attribute):
|
|||
return _lightning_get_first_attr_holder(model, attribute) is not None
|
||||
|
||||
|
||||
def lightning_getattr(model, attribute):
|
||||
def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
|
||||
"""
|
||||
Special getattr for Lightning. Checks for attribute in model namespace,
|
||||
the old hparams namespace/dict, and the datamodule.
|
||||
|
@ -338,7 +360,7 @@ def lightning_getattr(model, attribute):
|
|||
return getattr(holder, attribute)
|
||||
|
||||
|
||||
def lightning_setattr(model, attribute, value):
|
||||
def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None:
|
||||
"""
|
||||
Special setattr for Lightning. Checks for attribute in model namespace
|
||||
and the old hparams namespace/dict.
|
||||
|
|
Loading…
Reference in New Issue