Clean utilities/argparse and add missing tests (#6607)
This commit is contained in:
parent
870247ffe6
commit
853523ee64
|
@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
|
||||||
# Value has been passed as a flag => It is currently None, so we need to set it to True
|
# Value has been passed as a flag => It is currently None, so we need to set it to True
|
||||||
# We always set to True, regardless of the default value.
|
# We always set to True, regardless of the default value.
|
||||||
# Users must pass False directly, but when passing nothing True is assumed.
|
# Users must pass False directly, but when passing nothing True is assumed.
|
||||||
# i.e. the only way to disable somthing that defaults to True is to use the long form:
|
# i.e. the only way to disable something that defaults to True is to use the long form:
|
||||||
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
|
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
|
||||||
# which then becomes True here.
|
# which then becomes True here.
|
||||||
|
|
||||||
|
@ -242,9 +242,6 @@ def add_argparse_args(
|
||||||
if arg == 'track_grad_norm':
|
if arg == 'track_grad_norm':
|
||||||
use_type = float
|
use_type = float
|
||||||
|
|
||||||
if arg_default is inspect._empty:
|
|
||||||
arg_default = None
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f'--{arg}',
|
f'--{arg}',
|
||||||
dest=arg,
|
dest=arg,
|
||||||
|
@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]:
|
||||||
|
|
||||||
|
|
||||||
def _gpus_arg_default(x) -> Union[int, str]:
|
def _gpus_arg_default(x) -> Union[int, str]:
|
||||||
if ',' in x:
|
return _gpus_allowed_type(x)
|
||||||
return str(x)
|
|
||||||
else:
|
|
||||||
return int(x)
|
|
||||||
|
|
||||||
|
|
||||||
def _int_or_float_type(x) -> Union[int, float]:
|
def _int_or_float_type(x) -> Union[int, float]:
|
||||||
|
|
|
@ -1,17 +1,51 @@
|
||||||
import io
|
import io
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser, Namespace
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.utilities.argparse import (
|
from pytorch_lightning.utilities.argparse import (
|
||||||
add_argparse_args,
|
add_argparse_args,
|
||||||
|
from_argparse_args,
|
||||||
get_abbrev_qualified_cls_name,
|
get_abbrev_qualified_cls_name,
|
||||||
|
parse_argparser,
|
||||||
parse_args_from_docstring,
|
parse_args_from_docstring,
|
||||||
|
_gpus_arg_default,
|
||||||
|
_int_or_float_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ArgparseExample:
|
||||||
|
def __init__(self, a: int = 0, b: str = '', c: bool = False):
|
||||||
|
self.a = a
|
||||||
|
self.b = b
|
||||||
|
self.c = c
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_argparse_args():
|
||||||
|
args = Namespace(a=1, b='test', c=True, d='not valid')
|
||||||
|
my_instance = from_argparse_args(ArgparseExample, args)
|
||||||
|
assert my_instance.a == 1
|
||||||
|
assert my_instance.b == 'test'
|
||||||
|
assert my_instance.c
|
||||||
|
|
||||||
|
parser = ArgumentParser()
|
||||||
|
mock_trainer = MagicMock()
|
||||||
|
_ = from_argparse_args(mock_trainer, parser)
|
||||||
|
mock_trainer.parse_argparser.assert_called_once_with(parser)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_argparser():
|
||||||
|
args = Namespace(a=1, b='test', c=None, d='not valid')
|
||||||
|
new_args = parse_argparser(ArgparseExample, args)
|
||||||
|
assert new_args.a == 1
|
||||||
|
assert new_args.b == 'test'
|
||||||
|
assert new_args.c
|
||||||
|
assert new_args.d == 'not valid'
|
||||||
|
|
||||||
|
|
||||||
def test_parse_args_from_docstring_normal():
|
def test_parse_args_from_docstring_normal():
|
||||||
args_help = parse_args_from_docstring(
|
args_help = parse_args_from_docstring(
|
||||||
"""Constrain image dataset
|
"""Constrain image dataset
|
||||||
|
@ -168,3 +202,13 @@ def test_add_argparse_args_no_argument_group():
|
||||||
args = parser.parse_args(fake_argv)
|
args = parser.parse_args(fake_argv)
|
||||||
assert args.main_arg == "abc"
|
assert args.main_arg == "abc"
|
||||||
assert args.my_parameter == 2
|
assert args.my_parameter == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_gpus_arg_default():
|
||||||
|
assert _gpus_arg_default('1,2') == '1,2'
|
||||||
|
assert _gpus_arg_default('1') == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_int_or_float_type():
|
||||||
|
assert isinstance(_int_or_float_type('0.0'), float)
|
||||||
|
assert isinstance(_int_or_float_type('0'), int)
|
Loading…
Reference in New Issue