Fix mypy in `utilities.parsing` (#8132)

This commit is contained in:
Daniel Stancl 2021-07-08 01:32:12 +02:00 committed by GitHub
parent 9bbca402ff
commit 667def8d89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 28 deletions

View File

@ -17,8 +17,11 @@ import pickle
import types
from argparse import Namespace
from dataclasses import fields, is_dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Literal
import pytorch_lightning as pl
from pytorch_lightning.utilities.warnings import rank_zero_warn
@ -53,10 +56,10 @@ def str_to_bool(val: str) -> bool:
>>> str_to_bool('FALSE')
False
"""
val = str_to_bool_or_str(val)
if isinstance(val, bool):
return val
raise ValueError(f'invalid truth value {val}')
val_converted = str_to_bool_or_str(val)
if isinstance(val_converted, bool):
return val_converted
raise ValueError(f'invalid truth value {val_converted}')
def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
@ -71,13 +74,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
>>> str_to_bool_or_int("abc")
'abc'
"""
val = str_to_bool_or_str(val)
if isinstance(val, bool):
return val
val_converted = str_to_bool_or_str(val)
if isinstance(val_converted, bool):
return val_converted
try:
return int(val)
return int(val_converted)
except ValueError:
return val
return val_converted
def is_picklable(obj: object) -> bool:
@ -90,7 +93,7 @@ def is_picklable(obj: object) -> bool:
return False
def clean_namespace(hparams):
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
"""Removes all unpicklable entries from hparams"""
hparams_dict = hparams
@ -104,7 +107,7 @@ def clean_namespace(hparams):
del hparams_dict[k]
def parse_class_init_keys(cls) -> Tuple[str, str, str]:
def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, Optional[str], Optional[str]]:
"""Parse key words for standard self, *args and **kwargs
>>> class Model():
@ -120,10 +123,14 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
# self is always first
n_self = init_params[0].name
def _get_first_if_any(params, param_type):
def _get_first_if_any(
params: List[inspect.Parameter],
param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD],
) -> Optional[str]:
for p in params:
if p.kind == param_type:
return p.name
return None
n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)
@ -131,7 +138,7 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
return n_self, n_args, n_kwargs
def get_init_args(frame) -> dict:
def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
_, _, _, local_vars = inspect.getargvalues(frame)
if '__class__' not in local_vars:
return {}
@ -142,12 +149,18 @@ def get_init_args(frame) -> dict:
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_var, {}))
# kwargs_var might be None => raised an error by mypy
if kwargs_var:
local_args.update(local_args.get(kwargs_var, {}))
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
return local_args
def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
def collect_init_args(
frame: types.FrameType,
path_args: List[Dict[str, Any]],
inside: bool = False,
) -> List[Dict[str, Any]]:
"""
Recursively collects the arguments passed to the child constructors in the inheritance tree.
@ -162,6 +175,10 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
most specific class in the hierarchy.
"""
_, _, _, local_vars = inspect.getargvalues(frame)
# frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy
if not isinstance(frame.f_back, types.FrameType):
return path_args
if '__class__' in local_vars:
local_args = get_init_args(frame)
# recursive update
@ -172,7 +189,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
return path_args
def flatten_dict(source, result=None):
def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
if result is None:
result = {}
@ -187,7 +204,7 @@ def flatten_dict(source, result=None):
def save_hyperparameters(
obj: Any,
*args,
*args: Any,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None
) -> None:
@ -198,7 +215,12 @@ def save_hyperparameters(
return
if not frame:
frame = inspect.currentframe().f_back
current_frame = inspect.currentframe()
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
if current_frame:
frame = current_frame.f_back
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")
if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
@ -251,16 +273,16 @@ class AttributeDict(Dict):
"my-key": 3.14
"""
def __getattr__(self, key):
def __getattr__(self, key: str) -> Optional[Any]:
try:
return self[key]
except KeyError as exp:
raise AttributeError(f'Missing attribute "{key}"') from exp
def __setattr__(self, key, val):
def __setattr__(self, key: str, val: Any) -> None:
self[key] = val
def __repr__(self):
def __repr__(self) -> str:
if not len(self):
return ""
max_key_length = max([len(str(k)) for k in self])
@ -270,14 +292,14 @@ class AttributeDict(Dict):
return out
def _lightning_get_all_attr_holders(model, attribute):
def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> List[Any]:
"""
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
"""
trainer = getattr(model, 'trainer', None)
holders = []
holders: List[Any] = []
# Check if attribute in model
if hasattr(model, attribute):
@ -295,7 +317,7 @@ def _lightning_get_all_attr_holders(model, attribute):
return holders
def _lightning_get_first_attr_holder(model, attribute):
def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
"""
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
@ -308,7 +330,7 @@ def _lightning_get_first_attr_holder(model, attribute):
return holders[-1]
def lightning_hasattr(model, attribute):
def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool:
"""
Special hasattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule.
@ -316,7 +338,7 @@ def lightning_hasattr(model, attribute):
return _lightning_get_first_attr_holder(model, attribute) is not None
def lightning_getattr(model, attribute):
def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
"""
Special getattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule.
@ -338,7 +360,7 @@ def lightning_getattr(model, attribute):
return getattr(holder, attribute)
def lightning_setattr(model, attribute, value):
def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None:
"""
Special setattr for Lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.

View File

@ -185,6 +185,8 @@ ignore_errors = True
ignore_errors = False
[mypy-pytorch_lightning.utilities.device_parser]
ignore_errors = False
[mypy-pytorch_lightning.utilities.parsing]
ignore_errors = False
# todo: add proper typing to this module...
[mypy-pl_examples.*]