Fix argparse default value bug (#2526)

* Add failing test for bug

* Fix bug
This commit is contained in:
Espen Haugsdal 2020-07-09 13:10:30 +02:00 committed by GitHub
parent 9a367a899a
commit b3ebfec863
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 5 deletions

View File

@ -793,12 +793,32 @@ class Trainer(
else:
return int(x)
@staticmethod
def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
@classmethod
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
"""Parse CLI arguments, required for custom bool types."""
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
args = {k: True if v is None else v for k, v in vars(args).items()}
return Namespace(**args)
types_default = {
arg: (arg_types, arg_default) for arg, arg_types, arg_default in cls.get_init_arguments_and_types()
}
modified_args = {}
for k, v in vars(args).items():
if k in types_default and v is None:
# We need to figure out if the None is due to using nargs="?" or if it comes from the default value
arg_types, arg_default = types_default[k]
if bool in arg_types and isinstance(arg_default, bool):
# Value has been passed as a flag => It is currently None, so we need to set it to True
# We always set to True, regardless of the default value.
# Users must pass False directly, but when passing nothing True is assumed.
# i.e. the only way to disable somthing that defaults to True is to use the long form:
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
# which then becomes True here.
v = True
modified_args[k] = v
return Namespace(**modified_args)
@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':

View File

@ -102,7 +102,21 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
pytest.param('--tpu_cores=8',
{'tpu_cores': 8}),
pytest.param("--tpu_cores=1,",
{'tpu_cores': '1,'})
{'tpu_cores': '1,'}),
pytest.param(
"",
{
# These parameters are marked as Optional[...] in Trainer.__init__, with None as default.
# They should not be changed by the argparse interface.
"min_steps": None,
"max_steps": None,
"log_gpu_memory": None,
"distributed_backend": None,
"weights_save_path": None,
"truncated_bptt_steps": None,
"resume_from_checkpoint": None,
"profiler": None,
}),
])
def test_argparse_args_parsing(cli_args, expected):
"""Test multi type argument with bool."""