diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5000a716bc..12f2064d6c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -488,112 +488,7 @@ class Trainer( @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: - r"""Extends existing argparse by default `Trainer` attributes. - - Args: - parent_parser: - The custom cli arguments parser, which will be extended by - the Trainer default arguments. - - Only arguments of the allowed types (str, float, int, bool) will - extend the `parent_parser`. - - Examples: - >>> import argparse - >>> import pprint - >>> parser = argparse.ArgumentParser() - >>> parser = Trainer.add_argparse_args(parser) - >>> args = parser.parse_args([]) - >>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - {... - 'check_val_every_n_epoch': 1, - 'checkpoint_callback': True, - 'default_root_dir': None, - 'deterministic': False, - 'distributed_backend': None, - 'early_stop_callback': False, - ... - 'logger': True, - 'max_epochs': 1000, - 'max_steps': None, - 'min_epochs': 1, - 'min_steps': None, - ... - 'profiler': None, - 'progress_bar_refresh_rate': 1, - ...} - - """ - parser = ArgumentParser(parents=[parent_parser], add_help=False,) - - blacklist = ['kwargs'] - depr_arg_names = cls.get_deprecated_arg_names() + blacklist - - allowed_types = (str, int, float, bool) - - # TODO: get "help" from docstring :) - for arg, arg_types, arg_default in ( - at for at in argparse_utils.get_init_arguments_and_types(cls) if at[0] not in depr_arg_names - ): - arg_types = [at for at in allowed_types if at in arg_types] - if not arg_types: - # skip argument with not supported type - continue - arg_kwargs = {} - if bool in arg_types: - arg_kwargs.update(nargs="?", const=True) - # if the only arg type is bool - if len(arg_types) == 1: - use_type = parsing.str_to_bool - # if only two args (str, bool) - elif len(arg_types) == 2 and set(arg_types) == {str, bool}: - use_type = parsing.str_to_bool_or_str - else: - # filter out the bool as we need to use more general - use_type = [at for at in arg_types if at is not bool][0] - else: - use_type = arg_types[0] - - if arg == 'gpus' or arg == 'tpu_cores': - use_type = Trainer._gpus_allowed_type - arg_default = Trainer._gpus_arg_default - - # hack for types in (int, float) - if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types): - use_type = Trainer._int_or_float_type - - # hack for track_grad_norm - if arg == 'track_grad_norm': - use_type = float - - parser.add_argument( - f'--{arg}', - dest=arg, - default=arg_default, - type=use_type, - help='autogenerated by pl.Trainer', - **arg_kwargs, - ) - - return parser - - def _gpus_allowed_type(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) - - def _gpus_arg_default(x) -> Union[int, str]: - if ',' in x: - return str(x) - else: - return int(x) - - def _int_or_float_type(x) -> Union[int, float]: - if '.' in str(x): - return float(x) - else: - return int(x) + return argparse_utils.add_argparse_args(cls, parent_parser) @property def num_gpus(self) -> int: diff --git a/pytorch_lightning/utilities/argparse_utils.py b/pytorch_lightning/utilities/argparse_utils.py index 2caec9ef27..15a25da889 100644 --- a/pytorch_lightning/utilities/argparse_utils.py +++ b/pytorch_lightning/utilities/argparse_utils.py @@ -1,6 +1,7 @@ import inspect from argparse import ArgumentParser, Namespace from typing import Union, List, Tuple, Any +from pytorch_lightning.utilities import parsing def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): @@ -107,3 +108,116 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: name_type_default.append((arg, arg_types, arg_default)) return name_type_default + + +def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: + r"""Extends existing argparse by default `Trainer` attributes. + + Args: + parent_parser: + The custom cli arguments parser, which will be extended by + the Trainer default arguments. + + Only arguments of the allowed types (str, float, int, bool) will + extend the `parent_parser`. + + Examples: + >>> import argparse + >>> import pprint + >>> from pytorch_lightning import Trainer + >>> parser = argparse.ArgumentParser() + >>> parser = Trainer.add_argparse_args(parser) + >>> args = parser.parse_args([]) + >>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + {... + 'check_val_every_n_epoch': 1, + 'checkpoint_callback': True, + 'default_root_dir': None, + 'deterministic': False, + 'distributed_backend': None, + 'early_stop_callback': False, + ... + 'logger': True, + 'max_epochs': 1000, + 'max_steps': None, + 'min_epochs': 1, + 'min_steps': None, + ... + 'profiler': None, + 'progress_bar_refresh_rate': 1, + ...} + + """ + parser = ArgumentParser(parents=[parent_parser], add_help=False,) + + blacklist = ['kwargs'] + depr_arg_names = cls.get_deprecated_arg_names() + blacklist + + allowed_types = (str, int, float, bool) + + # TODO: get "help" from docstring :) + for arg, arg_types, arg_default in ( + at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names + ): + arg_types = [at for at in allowed_types if at in arg_types] + if not arg_types: + # skip argument with not supported type + continue + arg_kwargs = {} + if bool in arg_types: + arg_kwargs.update(nargs="?", const=True) + # if the only arg type is bool + if len(arg_types) == 1: + use_type = parsing.str_to_bool + # if only two args (str, bool) + elif len(arg_types) == 2 and set(arg_types) == {str, bool}: + use_type = parsing.str_to_bool_or_str + else: + # filter out the bool as we need to use more general + use_type = [at for at in arg_types if at is not bool][0] + else: + use_type = arg_types[0] + + if arg == 'gpus' or arg == 'tpu_cores': + use_type = _gpus_allowed_type + arg_default = _gpus_arg_default + + # hack for types in (int, float) + if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types): + use_type = _int_or_float_type + + # hack for track_grad_norm + if arg == 'track_grad_norm': + use_type = float + + parser.add_argument( + f'--{arg}', + dest=arg, + default=arg_default, + type=use_type, + help='autogenerated by pl.Trainer', + **arg_kwargs, + ) + + return parser + + +def _gpus_allowed_type(x) -> Union[int, str]: + if ',' in x: + return str(x) + else: + return int(x) + + +def _gpus_arg_default(x) -> Union[int, str]: + if ',' in x: + return str(x) + else: + return int(x) + + +def _int_or_float_type(x) -> Union[int, float]: + if '.' in str(x): + return float(x) + else: + return int(x)