ref: separate argparse (#3428)

This commit is contained in:
William Falcon 2020-09-09 18:07:56 -04:00 committed by GitHub
parent f7dac3ff6c
commit 9696484153
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 115 additions and 106 deletions

View File

@ -488,112 +488,7 @@ class Trainer(
@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `Trainer` attributes.
Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.
Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.
Examples:
>>> import argparse
>>> import pprint
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{...
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
'logger': True,
'max_epochs': 1000,
'max_steps': None,
'min_epochs': 1,
'min_steps': None,
...
'profiler': None,
'progress_bar_refresh_rate': 1,
...}
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
allowed_types = (str, int, float, bool)
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in argparse_utils.get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = parsing.str_to_bool
# if only two args (str, bool)
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
use_type = parsing.str_to_bool_or_str
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]
if arg == 'gpus' or arg == 'tpu_cores':
use_type = Trainer._gpus_allowed_type
arg_default = Trainer._gpus_arg_default
# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
use_type = Trainer._int_or_float_type
# hack for track_grad_norm
if arg == 'track_grad_norm':
use_type = float
parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)
return parser
def _gpus_allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
else:
return int(x)
return argparse_utils.add_argparse_args(cls, parent_parser)
@property
def num_gpus(self) -> int:

View File

@ -1,6 +1,7 @@
import inspect
from argparse import ArgumentParser, Namespace
from typing import Union, List, Tuple, Any
from pytorch_lightning.utilities import parsing
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@ -107,3 +108,116 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
name_type_default.append((arg, arg_types, arg_default))
return name_type_default
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extends existing argparse by default `Trainer` attributes.
Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.
Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.
Examples:
>>> import argparse
>>> import pprint
>>> from pytorch_lightning import Trainer
>>> parser = argparse.ArgumentParser()
>>> parser = Trainer.add_argparse_args(parser)
>>> args = parser.parse_args([])
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
{...
'check_val_every_n_epoch': 1,
'checkpoint_callback': True,
'default_root_dir': None,
'deterministic': False,
'distributed_backend': None,
'early_stop_callback': False,
...
'logger': True,
'max_epochs': 1000,
'max_steps': None,
'min_epochs': 1,
'min_steps': None,
...
'profiler': None,
'progress_bar_refresh_rate': 1,
...}
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
allowed_types = (str, int, float, bool)
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (
at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = parsing.str_to_bool
# if only two args (str, bool)
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
use_type = parsing.str_to_bool_or_str
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]
if arg == 'gpus' or arg == 'tpu_cores':
use_type = _gpus_allowed_type
arg_default = _gpus_arg_default
# hack for types in (int, float)
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
use_type = _int_or_float_type
# hack for track_grad_norm
if arg == 'track_grad_norm':
use_type = float
parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)
return parser
def _gpus_allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
def _int_or_float_type(x) -> Union[int, float]:
if '.' in str(x):
return float(x)
else:
return int(x)