Fix mypy in `utilities.parsing` (#8132)

This commit is contained in:
Daniel Stancl 2021-07-08 01:32:12 +02:00 committed by GitHub
parent 9bbca402ff
commit 667def8d89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 28 deletions

View File

@ -17,8 +17,11 @@ import pickle
import types import types
from argparse import Namespace from argparse import Namespace
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from typing import Any, Dict, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Literal
import pytorch_lightning as pl
from pytorch_lightning.utilities.warnings import rank_zero_warn from pytorch_lightning.utilities.warnings import rank_zero_warn
@ -53,10 +56,10 @@ def str_to_bool(val: str) -> bool:
>>> str_to_bool('FALSE') >>> str_to_bool('FALSE')
False False
""" """
val = str_to_bool_or_str(val) val_converted = str_to_bool_or_str(val)
if isinstance(val, bool): if isinstance(val_converted, bool):
return val return val_converted
raise ValueError(f'invalid truth value {val}') raise ValueError(f'invalid truth value {val_converted}')
def str_to_bool_or_int(val: str) -> Union[bool, int, str]: def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
@ -71,13 +74,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
>>> str_to_bool_or_int("abc") >>> str_to_bool_or_int("abc")
'abc' 'abc'
""" """
val = str_to_bool_or_str(val) val_converted = str_to_bool_or_str(val)
if isinstance(val, bool): if isinstance(val_converted, bool):
return val return val_converted
try: try:
return int(val) return int(val_converted)
except ValueError: except ValueError:
return val return val_converted
def is_picklable(obj: object) -> bool: def is_picklable(obj: object) -> bool:
@ -90,7 +93,7 @@ def is_picklable(obj: object) -> bool:
return False return False
def clean_namespace(hparams): def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
"""Removes all unpicklable entries from hparams""" """Removes all unpicklable entries from hparams"""
hparams_dict = hparams hparams_dict = hparams
@ -104,7 +107,7 @@ def clean_namespace(hparams):
del hparams_dict[k] del hparams_dict[k]
def parse_class_init_keys(cls) -> Tuple[str, str, str]: def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, Optional[str], Optional[str]]:
"""Parse key words for standard self, *args and **kwargs """Parse key words for standard self, *args and **kwargs
>>> class Model(): >>> class Model():
@ -120,10 +123,14 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
# self is always first # self is always first
n_self = init_params[0].name n_self = init_params[0].name
def _get_first_if_any(params, param_type): def _get_first_if_any(
params: List[inspect.Parameter],
param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD],
) -> Optional[str]:
for p in params: for p in params:
if p.kind == param_type: if p.kind == param_type:
return p.name return p.name
return None
n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL) n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD) n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)
@ -131,7 +138,7 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]:
return n_self, n_args, n_kwargs return n_self, n_args, n_kwargs
def get_init_args(frame) -> dict: def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
_, _, _, local_vars = inspect.getargvalues(frame) _, _, _, local_vars = inspect.getargvalues(frame)
if '__class__' not in local_vars: if '__class__' not in local_vars:
return {} return {}
@ -142,12 +149,18 @@ def get_init_args(frame) -> dict:
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
# only collect variables that appear in the signature # only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()} local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_var, {})) # kwargs_var might be None => raised an error by mypy
if kwargs_var:
local_args.update(local_args.get(kwargs_var, {}))
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
return local_args return local_args
def collect_init_args(frame, path_args: list, inside: bool = False) -> list: def collect_init_args(
frame: types.FrameType,
path_args: List[Dict[str, Any]],
inside: bool = False,
) -> List[Dict[str, Any]]:
""" """
Recursively collects the arguments passed to the child constructors in the inheritance tree. Recursively collects the arguments passed to the child constructors in the inheritance tree.
@ -162,6 +175,10 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
most specific class in the hierarchy. most specific class in the hierarchy.
""" """
_, _, _, local_vars = inspect.getargvalues(frame) _, _, _, local_vars = inspect.getargvalues(frame)
# frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy
if not isinstance(frame.f_back, types.FrameType):
return path_args
if '__class__' in local_vars: if '__class__' in local_vars:
local_args = get_init_args(frame) local_args = get_init_args(frame)
# recursive update # recursive update
@ -172,7 +189,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
return path_args return path_args
def flatten_dict(source, result=None): def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
if result is None: if result is None:
result = {} result = {}
@ -187,7 +204,7 @@ def flatten_dict(source, result=None):
def save_hyperparameters( def save_hyperparameters(
obj: Any, obj: Any,
*args, *args: Any,
ignore: Optional[Union[Sequence[str], str]] = None, ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None frame: Optional[types.FrameType] = None
) -> None: ) -> None:
@ -198,7 +215,12 @@ def save_hyperparameters(
return return
if not frame: if not frame:
frame = inspect.currentframe().f_back current_frame = inspect.currentframe()
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
if current_frame:
frame = current_frame.f_back
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")
if is_dataclass(obj): if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
@ -251,16 +273,16 @@ class AttributeDict(Dict):
"my-key": 3.14 "my-key": 3.14
""" """
def __getattr__(self, key): def __getattr__(self, key: str) -> Optional[Any]:
try: try:
return self[key] return self[key]
except KeyError as exp: except KeyError as exp:
raise AttributeError(f'Missing attribute "{key}"') from exp raise AttributeError(f'Missing attribute "{key}"') from exp
def __setattr__(self, key, val): def __setattr__(self, key: str, val: Any) -> None:
self[key] = val self[key] = val
def __repr__(self): def __repr__(self) -> str:
if not len(self): if not len(self):
return "" return ""
max_key_length = max([len(str(k)) for k in self]) max_key_length = max([len(str(k)) for k in self])
@ -270,14 +292,14 @@ class AttributeDict(Dict):
return out return out
def _lightning_get_all_attr_holders(model, attribute): def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> List[Any]:
""" """
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
""" """
trainer = getattr(model, 'trainer', None) trainer = getattr(model, 'trainer', None)
holders = [] holders: List[Any] = []
# Check if attribute in model # Check if attribute in model
if hasattr(model, attribute): if hasattr(model, attribute):
@ -295,7 +317,7 @@ def _lightning_get_all_attr_holders(model, attribute):
return holders return holders
def _lightning_get_first_attr_holder(model, attribute): def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
""" """
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
@ -308,7 +330,7 @@ def _lightning_get_first_attr_holder(model, attribute):
return holders[-1] return holders[-1]
def lightning_hasattr(model, attribute): def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool:
""" """
Special hasattr for Lightning. Checks for attribute in model namespace, Special hasattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. the old hparams namespace/dict, and the datamodule.
@ -316,7 +338,7 @@ def lightning_hasattr(model, attribute):
return _lightning_get_first_attr_holder(model, attribute) is not None return _lightning_get_first_attr_holder(model, attribute) is not None
def lightning_getattr(model, attribute): def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Optional[Any]:
""" """
Special getattr for Lightning. Checks for attribute in model namespace, Special getattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. the old hparams namespace/dict, and the datamodule.
@ -338,7 +360,7 @@ def lightning_getattr(model, attribute):
return getattr(holder, attribute) return getattr(holder, attribute)
def lightning_setattr(model, attribute, value): def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None:
""" """
Special setattr for Lightning. Checks for attribute in model namespace Special setattr for Lightning. Checks for attribute in model namespace
and the old hparams namespace/dict. and the old hparams namespace/dict.

View File

@ -185,6 +185,8 @@ ignore_errors = True
ignore_errors = False ignore_errors = False
[mypy-pytorch_lightning.utilities.device_parser] [mypy-pytorch_lightning.utilities.device_parser]
ignore_errors = False ignore_errors = False
[mypy-pytorch_lightning.utilities.parsing]
ignore_errors = False
# todo: add proper typing to this module... # todo: add proper typing to this module...
[mypy-pl_examples.*] [mypy-pl_examples.*]