extend arg parser (#1842)

* extend arg parser

* flake8

* tests

* example

* fix test
This commit is contained in:
Jirka Borovec 2020-05-14 23:56:11 +02:00 committed by GitHub
parent a6f6edd07d
commit bee0392c37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 72 additions and 30 deletions

View File

@ -1,7 +1,7 @@
import inspect
import os
import logging as python_logging
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
import torch
@ -132,7 +132,7 @@ class Trainer(
replace_sampler_ddp: bool = True,
progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
terminate_on_nan: bool = False,
auto_scale_batch_size: Optional[str] = None,
auto_scale_batch_size: Union[str, bool] = False,
amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0
default_save_path=None, # backward compatible, todo: remove in v0.8.0
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
@ -663,52 +663,70 @@ class Trainer(
# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
if at[0] not in depr_arg_names):
for allowed_type in (at for at in allowed_types if at in arg_types):
if allowed_type is bool:
def allowed_type(x):
arg_types = [at for at in allowed_types if at in arg_types]
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
if bool in arg_types:
arg_kwargs.update(nargs="?")
# if the only arg type is bool
if len(arg_types) == 1:
# redefine the type for ArgParser needed
def use_type(x):
return bool(parsing.strtobool(x))
else:
# filter out the bool as we need to use more general
use_type = [at for at in arg_types if at is not bool][0]
else:
use_type = arg_types[0]
# Bool args with default of True parsed as flags not key value pair
if arg_types == (bool,) and arg_default is False:
parser.add_argument(
f'--{arg}',
action='store_true',
dest=arg,
help='autogenerated by pl.Trainer'
)
continue
if arg == 'gpus':
use_type = Trainer._allowed_type
arg_default = Trainer._arg_default
if arg == 'gpus':
allowed_type = Trainer.allowed_type
arg_default = Trainer.arg_default
parser.add_argument(
f'--{arg}',
default=arg_default,
type=allowed_type,
dest=arg,
help='autogenerated by pl.Trainer'
)
break
parser.add_argument(
f'--{arg}',
dest=arg,
default=arg_default,
type=use_type,
help='autogenerated by pl.Trainer',
**arg_kwargs,
)
return parser
def allowed_type(x):
def _allowed_type(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
def arg_default(x):
def _arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
@staticmethod
def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
"""Parse CLI arguments, required for custom bool types."""
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
args = {k: True if v is None else v for k, v in vars(args).items()}
return Namespace(**args)
@classmethod
def from_argparse_args(cls, args, **kwargs):
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
"""create an instance from CLI arguments
Example:
>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args)
"""
if isinstance(args, ArgumentParser):
args = Trainer.parse_argparser(args)
params = vars(args)
params.update(**kwargs)
@ -797,6 +815,8 @@ class Trainer(
# Run auto batch size scaling
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
self.auto_scale_batch_size = 'power'
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
# Run learning rate finder:

View File

@ -88,3 +88,25 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
with pytest.raises(_UnkArgError):
parser.parse_args(cli_args)
# todo: add also testing for "gpus"
@pytest.mark.parametrize(['cli_args', 'expected'], [
pytest.param('--auto_lr_find --auto_scale_batch_size power',
{'auto_lr_find': True, 'auto_scale_batch_size': 'power', 'early_stop_callback': False}),
pytest.param('--auto_lr_find any_string --auto_scale_batch_size',
{'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}),
pytest.param('--early_stop_callback',
{'auto_lr_find': False, 'early_stop_callback': True, 'auto_scale_batch_size': False}),
])
def test_argparse_args_parsing(cli_args, expected):
"""Test multi type argument with bool."""
cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)
for k, v in expected.items():
assert getattr(args, k) == v
assert Trainer.from_argparse_args(args)