From b3ebfec863df8513f42e7211a29f857139e8ede4 Mon Sep 17 00:00:00 2001 From: Espen Haugsdal Date: Thu, 9 Jul 2020 13:10:30 +0200 Subject: [PATCH] Fix argparse default value bug (#2526) * Add failing test for bug * Fix bug --- pytorch_lightning/trainer/trainer.py | 28 ++++++++++++++++++++++++---- tests/trainer/test_trainer_cli.py | 16 +++++++++++++++- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1b3e053387..eec2175291 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -793,12 +793,32 @@ class Trainer( else: return int(x) - @staticmethod - def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: + @classmethod + def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: """Parse CLI arguments, required for custom bool types.""" args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser - args = {k: True if v is None else v for k, v in vars(args).items()} - return Namespace(**args) + + types_default = { + arg: (arg_types, arg_default) for arg, arg_types, arg_default in cls.get_init_arguments_and_types() + } + + modified_args = {} + for k, v in vars(args).items(): + if k in types_default and v is None: + # We need to figure out if the None is due to using nargs="?" or if it comes from the default value + arg_types, arg_default = types_default[k] + if bool in arg_types and isinstance(arg_default, bool): + # Value has been passed as a flag => It is currently None, so we need to set it to True + # We always set to True, regardless of the default value. + # Users must pass False directly, but when passing nothing True is assumed. + # i.e. the only way to disable somthing that defaults to True is to use the long form: + # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, + # which then becomes True here. + + v = True + + modified_args[k] = v + return Namespace(**modified_args) @classmethod def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer': diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index 51bbc96bd4..fd381c2513 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -102,7 +102,21 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch): pytest.param('--tpu_cores=8', {'tpu_cores': 8}), pytest.param("--tpu_cores=1,", - {'tpu_cores': '1,'}) + {'tpu_cores': '1,'}), + pytest.param( + "", + { + # These parameters are marked as Optional[...] in Trainer.__init__, with None as default. + # They should not be changed by the argparse interface. + "min_steps": None, + "max_steps": None, + "log_gpu_memory": None, + "distributed_backend": None, + "weights_save_path": None, + "truncated_bptt_steps": None, + "resume_from_checkpoint": None, + "profiler": None, + }), ]) def test_argparse_args_parsing(cli_args, expected): """Test multi type argument with bool."""