fix fast_dev_run parsing from cli (#7240)
This commit is contained in:
parent
14b8dd479a
commit
b6706470c1
|
@ -399,6 +399,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Fixed `parameters_to_ignore` not properly set to DDPWrapper ([#7239](https://github.com/PyTorchLightning/pytorch-lightning/pull/7239))
|
||||
|
||||
|
||||
- Fixed parsing of `fast_dev_run=True` with the built-in `ArgumentParser` ([#7240](https://github.com/PyTorchLightning/pytorch-lightning/pull/7240))
|
||||
|
||||
|
||||
|
||||
## [1.2.7] - 2021-04-06
|
||||
|
||||
### Fixed
|
||||
|
|
|
@ -17,7 +17,7 @@ from argparse import _ArgumentGroup, ArgumentParser, Namespace
|
|||
from contextlib import suppress
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str
|
||||
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str, str_to_bool_or_int
|
||||
|
||||
|
||||
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
|
||||
|
@ -222,6 +222,8 @@ def add_argparse_args(
|
|||
# if the only arg type is bool
|
||||
if len(arg_types) == 1:
|
||||
use_type = str_to_bool
|
||||
elif int in arg_types:
|
||||
use_type = str_to_bool_or_int
|
||||
elif str in arg_types:
|
||||
use_type = str_to_bool_or_str
|
||||
else:
|
||||
|
|
|
@ -56,6 +56,27 @@ def str_to_bool(val: str) -> bool:
|
|||
raise ValueError(f'invalid truth value {val}')
|
||||
|
||||
|
||||
def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
|
||||
"""Convert a string representation to truth of bool if possible, or otherwise try to convert it to an int.
|
||||
|
||||
>>> str_to_bool_or_int("FALSE")
|
||||
False
|
||||
>>> str_to_bool_or_int("1")
|
||||
True
|
||||
>>> str_to_bool_or_int("2")
|
||||
2
|
||||
>>> str_to_bool_or_int("abc")
|
||||
'abc'
|
||||
"""
|
||||
val = str_to_bool_or_str(val)
|
||||
if isinstance(val, bool):
|
||||
return val
|
||||
try:
|
||||
return int(val)
|
||||
except ValueError:
|
||||
return val
|
||||
|
||||
|
||||
def is_picklable(obj: object) -> bool:
|
||||
"""Tests if an object can be pickled"""
|
||||
|
||||
|
|
|
@ -175,6 +175,23 @@ def test_argparse_args_parsing(cli_args, expected):
|
|||
assert Trainer.from_argparse_args(args)
|
||||
|
||||
|
||||
@RunIf(min_python="3.7.0")
|
||||
@pytest.mark.parametrize('cli_args,expected', [
|
||||
('', False),
|
||||
('--fast_dev_run=0', False),
|
||||
('--fast_dev_run=True', True),
|
||||
('--fast_dev_run 2', 2),
|
||||
])
|
||||
def test_argparse_args_parsing_fast_dev_run(cli_args, expected):
|
||||
"""Test multi type argument with bool."""
|
||||
cli_args = cli_args.split(' ') if cli_args else []
|
||||
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
|
||||
parser = ArgumentParser(add_help=False)
|
||||
parser = Trainer.add_argparse_args(parent_parser=parser)
|
||||
args = Trainer.parse_argparser(parser)
|
||||
assert args.fast_dev_run is expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [
|
||||
pytest.param('', None, None),
|
||||
pytest.param('--gpus 1', 1, [0]),
|
||||
|
|
Loading…
Reference in New Issue