diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index c7b57fe3fd..bc64804d80 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -17,8 +17,11 @@ import pickle import types from argparse import Namespace from dataclasses import fields, is_dataclass -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing_extensions import Literal + +import pytorch_lightning as pl from pytorch_lightning.utilities.warnings import rank_zero_warn @@ -53,10 +56,10 @@ def str_to_bool(val: str) -> bool: >>> str_to_bool('FALSE') False """ - val = str_to_bool_or_str(val) - if isinstance(val, bool): - return val - raise ValueError(f'invalid truth value {val}') + val_converted = str_to_bool_or_str(val) + if isinstance(val_converted, bool): + return val_converted + raise ValueError(f'invalid truth value {val_converted}') def str_to_bool_or_int(val: str) -> Union[bool, int, str]: @@ -71,13 +74,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]: >>> str_to_bool_or_int("abc") 'abc' """ - val = str_to_bool_or_str(val) - if isinstance(val, bool): - return val + val_converted = str_to_bool_or_str(val) + if isinstance(val_converted, bool): + return val_converted try: - return int(val) + return int(val_converted) except ValueError: - return val + return val_converted def is_picklable(obj: object) -> bool: @@ -90,7 +93,7 @@ def is_picklable(obj: object) -> bool: return False -def clean_namespace(hparams): +def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: """Removes all unpicklable entries from hparams""" hparams_dict = hparams @@ -104,7 +107,7 @@ def clean_namespace(hparams): del hparams_dict[k] -def parse_class_init_keys(cls) -> Tuple[str, str, str]: +def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, Optional[str], Optional[str]]: """Parse key words for standard self, *args and **kwargs >>> class Model(): @@ -120,10 +123,14 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]: # self is always first n_self = init_params[0].name - def _get_first_if_any(params, param_type): + def _get_first_if_any( + params: List[inspect.Parameter], + param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], + ) -> Optional[str]: for p in params: if p.kind == param_type: return p.name + return None n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL) n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD) @@ -131,7 +138,7 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]: return n_self, n_args, n_kwargs -def get_init_args(frame) -> dict: +def get_init_args(frame: types.FrameType) -> Dict[str, Any]: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: return {} @@ -142,12 +149,18 @@ def get_init_args(frame) -> dict: exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') # only collect variables that appear in the signature local_args = {k: local_vars[k] for k in init_parameters.keys()} - local_args.update(local_args.get(kwargs_var, {})) + # kwargs_var might be None => raised an error by mypy + if kwargs_var: + local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} return local_args -def collect_init_args(frame, path_args: list, inside: bool = False) -> list: +def collect_init_args( + frame: types.FrameType, + path_args: List[Dict[str, Any]], + inside: bool = False, +) -> List[Dict[str, Any]]: """ Recursively collects the arguments passed to the child constructors in the inheritance tree. @@ -162,6 +175,10 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list: most specific class in the hierarchy. """ _, _, _, local_vars = inspect.getargvalues(frame) + # frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy + if not isinstance(frame.f_back, types.FrameType): + return path_args + if '__class__' in local_vars: local_args = get_init_args(frame) # recursive update @@ -172,7 +189,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list: return path_args -def flatten_dict(source, result=None): +def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: if result is None: result = {} @@ -187,7 +204,7 @@ def flatten_dict(source, result=None): def save_hyperparameters( obj: Any, - *args, + *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: @@ -198,7 +215,12 @@ def save_hyperparameters( return if not frame: - frame = inspect.currentframe().f_back + current_frame = inspect.currentframe() + # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available + if current_frame: + frame = current_frame.f_back + if not isinstance(frame, types.FrameType): + raise AttributeError("There is no `frame` available while being required.") if is_dataclass(obj): init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} @@ -251,16 +273,16 @@ class AttributeDict(Dict): "my-key": 3.14 """ - def __getattr__(self, key): + def __getattr__(self, key: str) -> Optional[Any]: try: return self[key] except KeyError as exp: raise AttributeError(f'Missing attribute "{key}"') from exp - def __setattr__(self, key, val): + def __setattr__(self, key: str, val: Any) -> None: self[key] = val - def __repr__(self): + def __repr__(self) -> str: if not len(self): return "" max_key_length = max([len(str(k)) for k in self]) @@ -270,14 +292,14 @@ class AttributeDict(Dict): return out -def _lightning_get_all_attr_holders(model, attribute): +def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> List[Any]: """ Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ trainer = getattr(model, 'trainer', None) - holders = [] + holders: List[Any] = [] # Check if attribute in model if hasattr(model, attribute): @@ -295,7 +317,7 @@ def _lightning_get_all_attr_holders(model, attribute): return holders -def _lightning_get_first_attr_holder(model, attribute): +def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]: """ Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, @@ -308,7 +330,7 @@ def _lightning_get_first_attr_holder(model, attribute): return holders[-1] -def lightning_hasattr(model, attribute): +def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool: """ Special hasattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -316,7 +338,7 @@ def lightning_hasattr(model, attribute): return _lightning_get_first_attr_holder(model, attribute) is not None -def lightning_getattr(model, attribute): +def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Optional[Any]: """ Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -338,7 +360,7 @@ def lightning_getattr(model, attribute): return getattr(holder, attribute) -def lightning_setattr(model, attribute, value): +def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None: """ Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. diff --git a/setup.cfg b/setup.cfg index 6a1107ea9a..8dae1dc720 100644 --- a/setup.cfg +++ b/setup.cfg @@ -185,6 +185,8 @@ ignore_errors = True ignore_errors = False [mypy-pytorch_lightning.utilities.device_parser] ignore_errors = False +[mypy-pytorch_lightning.utilities.parsing] +ignore_errors = False # todo: add proper typing to this module... [mypy-pl_examples.*]