extend arg parser (#1842)
* extend arg parser * flake8 * tests * example * fix test
This commit is contained in:
parent
a6f6edd07d
commit
bee0392c37
|
@ -1,7 +1,7 @@
|
|||
import inspect
|
||||
import os
|
||||
import logging as python_logging
|
||||
from argparse import ArgumentParser
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
|
||||
|
||||
import torch
|
||||
|
@ -132,7 +132,7 @@ class Trainer(
|
|||
replace_sampler_ddp: bool = True,
|
||||
progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
|
||||
terminate_on_nan: bool = False,
|
||||
auto_scale_batch_size: Optional[str] = None,
|
||||
auto_scale_batch_size: Union[str, bool] = False,
|
||||
amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0
|
||||
default_save_path=None, # backward compatible, todo: remove in v0.8.0
|
||||
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
|
||||
|
@ -663,52 +663,70 @@ class Trainer(
|
|||
# TODO: get "help" from docstring :)
|
||||
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
|
||||
if at[0] not in depr_arg_names):
|
||||
|
||||
for allowed_type in (at for at in allowed_types if at in arg_types):
|
||||
if allowed_type is bool:
|
||||
def allowed_type(x):
|
||||
arg_types = [at for at in allowed_types if at in arg_types]
|
||||
if not arg_types:
|
||||
# skip argument with not supported type
|
||||
continue
|
||||
arg_kwargs = {}
|
||||
if bool in arg_types:
|
||||
arg_kwargs.update(nargs="?")
|
||||
# if the only arg type is bool
|
||||
if len(arg_types) == 1:
|
||||
# redefine the type for ArgParser needed
|
||||
def use_type(x):
|
||||
return bool(parsing.strtobool(x))
|
||||
else:
|
||||
# filter out the bool as we need to use more general
|
||||
use_type = [at for at in arg_types if at is not bool][0]
|
||||
else:
|
||||
use_type = arg_types[0]
|
||||
|
||||
# Bool args with default of True parsed as flags not key value pair
|
||||
if arg_types == (bool,) and arg_default is False:
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
action='store_true',
|
||||
dest=arg,
|
||||
help='autogenerated by pl.Trainer'
|
||||
)
|
||||
continue
|
||||
if arg == 'gpus':
|
||||
use_type = Trainer._allowed_type
|
||||
arg_default = Trainer._arg_default
|
||||
|
||||
if arg == 'gpus':
|
||||
allowed_type = Trainer.allowed_type
|
||||
arg_default = Trainer.arg_default
|
||||
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
default=arg_default,
|
||||
type=allowed_type,
|
||||
dest=arg,
|
||||
help='autogenerated by pl.Trainer'
|
||||
)
|
||||
break
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
dest=arg,
|
||||
default=arg_default,
|
||||
type=use_type,
|
||||
help='autogenerated by pl.Trainer',
|
||||
**arg_kwargs,
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def allowed_type(x):
|
||||
def _allowed_type(x) -> Union[int, str]:
|
||||
if ',' in x:
|
||||
return str(x)
|
||||
else:
|
||||
return int(x)
|
||||
|
||||
def arg_default(x):
|
||||
def _arg_default(x) -> Union[int, str]:
|
||||
if ',' in x:
|
||||
return str(x)
|
||||
else:
|
||||
return int(x)
|
||||
|
||||
@staticmethod
|
||||
def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
|
||||
"""Parse CLI arguments, required for custom bool types."""
|
||||
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
|
||||
args = {k: True if v is None else v for k, v in vars(args).items()}
|
||||
return Namespace(**args)
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args, **kwargs):
|
||||
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
|
||||
"""create an instance from CLI arguments
|
||||
|
||||
Example:
|
||||
>>> parser = ArgumentParser(add_help=False)
|
||||
>>> parser = Trainer.add_argparse_args(parser)
|
||||
>>> args = Trainer.parse_argparser(parser.parse_args(""))
|
||||
>>> trainer = Trainer.from_argparse_args(args)
|
||||
"""
|
||||
if isinstance(args, ArgumentParser):
|
||||
args = Trainer.parse_argparser(args)
|
||||
params = vars(args)
|
||||
params.update(**kwargs)
|
||||
|
||||
|
@ -797,6 +815,8 @@ class Trainer(
|
|||
|
||||
# Run auto batch size scaling
|
||||
if self.auto_scale_batch_size:
|
||||
if isinstance(self.auto_scale_batch_size, bool):
|
||||
self.auto_scale_batch_size = 'power'
|
||||
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
|
||||
|
||||
# Run learning rate finder:
|
||||
|
|
|
@ -88,3 +88,25 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
|
|||
|
||||
with pytest.raises(_UnkArgError):
|
||||
parser.parse_args(cli_args)
|
||||
|
||||
|
||||
# todo: add also testing for "gpus"
|
||||
@pytest.mark.parametrize(['cli_args', 'expected'], [
|
||||
pytest.param('--auto_lr_find --auto_scale_batch_size power',
|
||||
{'auto_lr_find': True, 'auto_scale_batch_size': 'power', 'early_stop_callback': False}),
|
||||
pytest.param('--auto_lr_find any_string --auto_scale_batch_size',
|
||||
{'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}),
|
||||
pytest.param('--early_stop_callback',
|
||||
{'auto_lr_find': False, 'early_stop_callback': True, 'auto_scale_batch_size': False}),
|
||||
])
|
||||
def test_argparse_args_parsing(cli_args, expected):
|
||||
"""Test multi type argument with bool."""
|
||||
cli_args = cli_args.split(' ') if cli_args else []
|
||||
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
|
||||
parser = ArgumentParser(add_help=False)
|
||||
parser = Trainer.add_argparse_args(parent_parser=parser)
|
||||
args = Trainer.parse_argparser(parser)
|
||||
|
||||
for k, v in expected.items():
|
||||
assert getattr(args, k) == v
|
||||
assert Trainer.from_argparse_args(args)
|
||||
|
|
Loading…
Reference in New Issue