ref: separate argparse (#3428)
This commit is contained in:
parent
f7dac3ff6c
commit
9696484153
|
@ -488,112 +488,7 @@ class Trainer(
|
|||
|
||||
@classmethod
|
||||
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
|
||||
r"""Extends existing argparse by default `Trainer` attributes.
|
||||
|
||||
Args:
|
||||
parent_parser:
|
||||
The custom cli arguments parser, which will be extended by
|
||||
the Trainer default arguments.
|
||||
|
||||
Only arguments of the allowed types (str, float, int, bool) will
|
||||
extend the `parent_parser`.
|
||||
|
||||
Examples:
|
||||
>>> import argparse
|
||||
>>> import pprint
|
||||
>>> parser = argparse.ArgumentParser()
|
||||
>>> parser = Trainer.add_argparse_args(parser)
|
||||
>>> args = parser.parse_args([])
|
||||
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
{...
|
||||
'check_val_every_n_epoch': 1,
|
||||
'checkpoint_callback': True,
|
||||
'default_root_dir': None,
|
||||
'deterministic': False,
|
||||
'distributed_backend': None,
|
||||
'early_stop_callback': False,
|
||||
...
|
||||
'logger': True,
|
||||
'max_epochs': 1000,
|
||||
'max_steps': None,
|
||||
'min_epochs': 1,
|
||||
'min_steps': None,
|
||||
...
|
||||
'profiler': None,
|
||||
'progress_bar_refresh_rate': 1,
|
||||
...}
|
||||
|
||||
"""
|
||||
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
|
||||
|
||||
blacklist = ['kwargs']
|
||||
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
|
||||
|
||||
allowed_types = (str, int, float, bool)
|
||||
|
||||
# TODO: get "help" from docstring :)
|
||||
for arg, arg_types, arg_default in (
|
||||
at for at in argparse_utils.get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
|
||||
):
|
||||
arg_types = [at for at in allowed_types if at in arg_types]
|
||||
if not arg_types:
|
||||
# skip argument with not supported type
|
||||
continue
|
||||
arg_kwargs = {}
|
||||
if bool in arg_types:
|
||||
arg_kwargs.update(nargs="?", const=True)
|
||||
# if the only arg type is bool
|
||||
if len(arg_types) == 1:
|
||||
use_type = parsing.str_to_bool
|
||||
# if only two args (str, bool)
|
||||
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
|
||||
use_type = parsing.str_to_bool_or_str
|
||||
else:
|
||||
# filter out the bool as we need to use more general
|
||||
use_type = [at for at in arg_types if at is not bool][0]
|
||||
else:
|
||||
use_type = arg_types[0]
|
||||
|
||||
if arg == 'gpus' or arg == 'tpu_cores':
|
||||
use_type = Trainer._gpus_allowed_type
|
||||
arg_default = Trainer._gpus_arg_default
|
||||
|
||||
# hack for types in (int, float)
|
||||
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
|
||||
use_type = Trainer._int_or_float_type
|
||||
|
||||
# hack for track_grad_norm
|
||||
if arg == 'track_grad_norm':
|
||||
use_type = float
|
||||
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
dest=arg,
|
||||
default=arg_default,
|
||||
type=use_type,
|
||||
help='autogenerated by pl.Trainer',
|
||||
**arg_kwargs,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def _gpus_allowed_type(x) -> Union[int, str]:
|
||||
if ',' in x:
|
||||
return str(x)
|
||||
else:
|
||||
return int(x)
|
||||
|
||||
def _gpus_arg_default(x) -> Union[int, str]:
|
||||
if ',' in x:
|
||||
return str(x)
|
||||
else:
|
||||
return int(x)
|
||||
|
||||
def _int_or_float_type(x) -> Union[int, float]:
|
||||
if '.' in str(x):
|
||||
return float(x)
|
||||
else:
|
||||
return int(x)
|
||||
return argparse_utils.add_argparse_args(cls, parent_parser)
|
||||
|
||||
@property
|
||||
def num_gpus(self) -> int:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import inspect
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Union, List, Tuple, Any
|
||||
from pytorch_lightning.utilities import parsing
|
||||
|
||||
|
||||
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
|
||||
|
@ -107,3 +108,116 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
|
|||
name_type_default.append((arg, arg_types, arg_default))
|
||||
|
||||
return name_type_default
|
||||
|
||||
|
||||
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
|
||||
r"""Extends existing argparse by default `Trainer` attributes.
|
||||
|
||||
Args:
|
||||
parent_parser:
|
||||
The custom cli arguments parser, which will be extended by
|
||||
the Trainer default arguments.
|
||||
|
||||
Only arguments of the allowed types (str, float, int, bool) will
|
||||
extend the `parent_parser`.
|
||||
|
||||
Examples:
|
||||
>>> import argparse
|
||||
>>> import pprint
|
||||
>>> from pytorch_lightning import Trainer
|
||||
>>> parser = argparse.ArgumentParser()
|
||||
>>> parser = Trainer.add_argparse_args(parser)
|
||||
>>> args = parser.parse_args([])
|
||||
>>> pprint.pprint(vars(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
||||
{...
|
||||
'check_val_every_n_epoch': 1,
|
||||
'checkpoint_callback': True,
|
||||
'default_root_dir': None,
|
||||
'deterministic': False,
|
||||
'distributed_backend': None,
|
||||
'early_stop_callback': False,
|
||||
...
|
||||
'logger': True,
|
||||
'max_epochs': 1000,
|
||||
'max_steps': None,
|
||||
'min_epochs': 1,
|
||||
'min_steps': None,
|
||||
...
|
||||
'profiler': None,
|
||||
'progress_bar_refresh_rate': 1,
|
||||
...}
|
||||
|
||||
"""
|
||||
parser = ArgumentParser(parents=[parent_parser], add_help=False,)
|
||||
|
||||
blacklist = ['kwargs']
|
||||
depr_arg_names = cls.get_deprecated_arg_names() + blacklist
|
||||
|
||||
allowed_types = (str, int, float, bool)
|
||||
|
||||
# TODO: get "help" from docstring :)
|
||||
for arg, arg_types, arg_default in (
|
||||
at for at in get_init_arguments_and_types(cls) if at[0] not in depr_arg_names
|
||||
):
|
||||
arg_types = [at for at in allowed_types if at in arg_types]
|
||||
if not arg_types:
|
||||
# skip argument with not supported type
|
||||
continue
|
||||
arg_kwargs = {}
|
||||
if bool in arg_types:
|
||||
arg_kwargs.update(nargs="?", const=True)
|
||||
# if the only arg type is bool
|
||||
if len(arg_types) == 1:
|
||||
use_type = parsing.str_to_bool
|
||||
# if only two args (str, bool)
|
||||
elif len(arg_types) == 2 and set(arg_types) == {str, bool}:
|
||||
use_type = parsing.str_to_bool_or_str
|
||||
else:
|
||||
# filter out the bool as we need to use more general
|
||||
use_type = [at for at in arg_types if at is not bool][0]
|
||||
else:
|
||||
use_type = arg_types[0]
|
||||
|
||||
if arg == 'gpus' or arg == 'tpu_cores':
|
||||
use_type = _gpus_allowed_type
|
||||
arg_default = _gpus_arg_default
|
||||
|
||||
# hack for types in (int, float)
|
||||
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
|
||||
use_type = _int_or_float_type
|
||||
|
||||
# hack for track_grad_norm
|
||||
if arg == 'track_grad_norm':
|
||||
use_type = float
|
||||
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
dest=arg,
|
||||
default=arg_default,
|
||||
type=use_type,
|
||||
help='autogenerated by pl.Trainer',
|
||||
**arg_kwargs,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _gpus_allowed_type(x) -> Union[int, str]:
|
||||
if ',' in x:
|
||||
return str(x)
|
||||
else:
|
||||
return int(x)
|
||||
|
||||
|
||||
def _gpus_arg_default(x) -> Union[int, str]:
|
||||
if ',' in x:
|
||||
return str(x)
|
||||
else:
|
||||
return int(x)
|
||||
|
||||
|
||||
def _int_or_float_type(x) -> Union[int, float]:
|
||||
if '.' in str(x):
|
||||
return float(x)
|
||||
else:
|
||||
return int(x)
|
||||
|
|
Loading…
Reference in New Issue