diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 9edc71997f..6f91397bd0 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -17,7 +17,7 @@ from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress from typing import Any, Dict, List, Tuple, Union -from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_str, str_to_bool_or_int +from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index a26883e897..d5153e6b82 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -176,12 +176,14 @@ def test_argparse_args_parsing(cli_args, expected): @RunIf(min_python="3.7.0") -@pytest.mark.parametrize('cli_args,expected', [ - ('', False), - ('--fast_dev_run=0', False), - ('--fast_dev_run=True', True), - ('--fast_dev_run 2', 2), -]) +@pytest.mark.parametrize( + 'cli_args,expected', [ + ('', False), + ('--fast_dev_run=0', False), + ('--fast_dev_run=True', True), + ('--fast_dev_run 2', 2), + ] +) def test_argparse_args_parsing_fast_dev_run(cli_args, expected): """Test multi type argument with bool.""" cli_args = cli_args.split(' ') if cli_args else [] diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index 57e49df2df..f50f44cf10 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -28,6 +28,7 @@ from pytorch_lightning.utilities.parsing import ( lightning_setattr, parse_class_init_keys, str_to_bool, + str_to_bool_or_int, str_to_bool_or_str, ) @@ -165,7 +166,7 @@ def test_lightning_setattr(tmpdir, model_cases): lightning_setattr(m, "this_attr_not_exist", None) -def test_str_to_bool_or_str(tmpdir): +def test_str_to_bool_or_str(): true_cases = ['y', 'yes', 't', 'true', 'on', '1'] false_cases = ['n', 'no', 'f', 'false', 'off', '0'] other_cases = ['yyeess', 'noooo', 'lightning'] @@ -180,7 +181,7 @@ def test_str_to_bool_or_str(tmpdir): assert str_to_bool_or_str(case) == case -def test_str_to_bool(tmpdir): +def test_str_to_bool(): true_cases = ['y', 'yes', 't', 'true', 'on', '1'] false_cases = ['n', 'no', 'f', 'false', 'off', '0'] other_cases = ['yyeess', 'noooo', 'lightning'] @@ -196,6 +197,14 @@ def test_str_to_bool(tmpdir): str_to_bool(case) +def test_str_to_bool_or_int(): + assert str_to_bool_or_int("0") is False + assert str_to_bool_or_int("1") is True + assert str_to_bool_or_int("true") is True + assert str_to_bool_or_int("2") == 2 + assert str_to_bool_or_int("abc") == "abc" + + def test_is_picklable(tmpdir): # See the full list of picklable types at # https://docs.python.org/3/library/pickle.html#pickle-picklable