Fix argparse default value bug (#2526)
* Add failing test for bug * Fix bug
This commit is contained in:
parent
9a367a899a
commit
b3ebfec863
|
@ -793,12 +793,32 @@ class Trainer(
|
|||
else:
|
||||
return int(x)
|
||||
|
||||
@staticmethod
|
||||
def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
|
||||
@classmethod
|
||||
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
|
||||
"""Parse CLI arguments, required for custom bool types."""
|
||||
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
|
||||
args = {k: True if v is None else v for k, v in vars(args).items()}
|
||||
return Namespace(**args)
|
||||
|
||||
types_default = {
|
||||
arg: (arg_types, arg_default) for arg, arg_types, arg_default in cls.get_init_arguments_and_types()
|
||||
}
|
||||
|
||||
modified_args = {}
|
||||
for k, v in vars(args).items():
|
||||
if k in types_default and v is None:
|
||||
# We need to figure out if the None is due to using nargs="?" or if it comes from the default value
|
||||
arg_types, arg_default = types_default[k]
|
||||
if bool in arg_types and isinstance(arg_default, bool):
|
||||
# Value has been passed as a flag => It is currently None, so we need to set it to True
|
||||
# We always set to True, regardless of the default value.
|
||||
# Users must pass False directly, but when passing nothing True is assumed.
|
||||
# i.e. the only way to disable somthing that defaults to True is to use the long form:
|
||||
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
|
||||
# which then becomes True here.
|
||||
|
||||
v = True
|
||||
|
||||
modified_args[k] = v
|
||||
return Namespace(**modified_args)
|
||||
|
||||
@classmethod
|
||||
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
|
||||
|
|
|
@ -102,7 +102,21 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
|
|||
pytest.param('--tpu_cores=8',
|
||||
{'tpu_cores': 8}),
|
||||
pytest.param("--tpu_cores=1,",
|
||||
{'tpu_cores': '1,'})
|
||||
{'tpu_cores': '1,'}),
|
||||
pytest.param(
|
||||
"",
|
||||
{
|
||||
# These parameters are marked as Optional[...] in Trainer.__init__, with None as default.
|
||||
# They should not be changed by the argparse interface.
|
||||
"min_steps": None,
|
||||
"max_steps": None,
|
||||
"log_gpu_memory": None,
|
||||
"distributed_backend": None,
|
||||
"weights_save_path": None,
|
||||
"truncated_bptt_steps": None,
|
||||
"resume_from_checkpoint": None,
|
||||
"profiler": None,
|
||||
}),
|
||||
])
|
||||
def test_argparse_args_parsing(cli_args, expected):
|
||||
"""Test multi type argument with bool."""
|
||||
|
|
Loading…
Reference in New Issue