diff --git a/CHANGELOG.md b/CHANGELOG.md index 5bdaf6c95f..04840e386b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -399,6 +399,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `parameters_to_ignore` not properly set to DDPWrapper ([#7239](https://github.com/PyTorchLightning/pytorch-lightning/pull/7239)) +- Fixed parsing of `fast_dev_run=True` with the built-in `ArgumentParser` ([#7240](https://github.com/PyTorchLightning/pytorch-lightning/pull/7240)) + + + ## [1.2.7] - 2021-04-06 ### Fixed diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index dc99b923c6..9edc71997f 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -17,7 +17,7 @@ from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress from typing import Any, Dict, List, Tuple, Union -from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str +from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str, str_to_bool_or_int def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): @@ -222,6 +222,8 @@ def add_argparse_args( # if the only arg type is bool if len(arg_types) == 1: use_type = str_to_bool + elif int in arg_types: + use_type = str_to_bool_or_int elif str in arg_types: use_type = str_to_bool_or_str else: diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index ae83ba15a9..37016a1293 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -56,6 +56,27 @@ def str_to_bool(val: str) -> bool: raise ValueError(f'invalid truth value {val}') +def str_to_bool_or_int(val: str) -> Union[bool, int, str]: + """Convert a string representation to truth of bool if possible, or otherwise try to convert it to an int. + + >>> str_to_bool_or_int("FALSE") + False + >>> str_to_bool_or_int("1") + True + >>> str_to_bool_or_int("2") + 2 + >>> str_to_bool_or_int("abc") + 'abc' + """ + val = str_to_bool_or_str(val) + if isinstance(val, bool): + return val + try: + return int(val) + except ValueError: + return val + + def is_picklable(obj: object) -> bool: """Tests if an object can be pickled""" diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index 7b91bcd941..a26883e897 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -175,6 +175,23 @@ def test_argparse_args_parsing(cli_args, expected): assert Trainer.from_argparse_args(args) +@RunIf(min_python="3.7.0") +@pytest.mark.parametrize('cli_args,expected', [ + ('', False), + ('--fast_dev_run=0', False), + ('--fast_dev_run=True', True), + ('--fast_dev_run 2', 2), +]) +def test_argparse_args_parsing_fast_dev_run(cli_args, expected): + """Test multi type argument with bool.""" + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parent_parser=parser) + args = Trainer.parse_argparser(parser) + assert args.fast_dev_run is expected + + @pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [ pytest.param('', None, None), pytest.param('--gpus 1', 1, [0]),