fix fast_dev_run parsing from cli (#7240)

This commit is contained in:
Adrian Wälchli 2021-04-29 21:46:20 +02:00 committed by GitHub
parent 14b8dd479a
commit b6706470c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 45 additions and 1 deletions

View File

@ -399,6 +399,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `parameters_to_ignore` not properly set to DDPWrapper ([#7239](https://github.com/PyTorchLightning/pytorch-lightning/pull/7239))
- Fixed parsing of `fast_dev_run=True` with the built-in `ArgumentParser` ([#7240](https://github.com/PyTorchLightning/pytorch-lightning/pull/7240))
## [1.2.7] - 2021-04-06
### Fixed

View File

@ -17,7 +17,7 @@ from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from typing import Any, Dict, List, Tuple, Union
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str, str_to_bool_or_int
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
@ -222,6 +222,8 @@ def add_argparse_args(
# if the only arg type is bool
if len(arg_types) == 1:
use_type = str_to_bool
elif int in arg_types:
use_type = str_to_bool_or_int
elif str in arg_types:
use_type = str_to_bool_or_str
else:

View File

@ -56,6 +56,27 @@ def str_to_bool(val: str) -> bool:
raise ValueError(f'invalid truth value {val}')
def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
"""Convert a string representation to truth of bool if possible, or otherwise try to convert it to an int.
>>> str_to_bool_or_int("FALSE")
False
>>> str_to_bool_or_int("1")
True
>>> str_to_bool_or_int("2")
2
>>> str_to_bool_or_int("abc")
'abc'
"""
val = str_to_bool_or_str(val)
if isinstance(val, bool):
return val
try:
return int(val)
except ValueError:
return val
def is_picklable(obj: object) -> bool:
"""Tests if an object can be pickled"""

View File

@ -175,6 +175,23 @@ def test_argparse_args_parsing(cli_args, expected):
assert Trainer.from_argparse_args(args)
@RunIf(min_python="3.7.0")
@pytest.mark.parametrize('cli_args,expected', [
('', False),
('--fast_dev_run=0', False),
('--fast_dev_run=True', True),
('--fast_dev_run 2', 2),
])
def test_argparse_args_parsing_fast_dev_run(cli_args, expected):
"""Test multi type argument with bool."""
cli_args = cli_args.split(' ') if cli_args else []
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)
args = Trainer.parse_argparser(parser)
assert args.fast_dev_run is expected
@pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [
pytest.param('', None, None),
pytest.param('--gpus 1', 1, [0]),