Clean utilities/argparse and add missing tests (#6607)
This commit is contained in:
parent
870247ffe6
commit
853523ee64
|
@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
|
|||
# Value has been passed as a flag => It is currently None, so we need to set it to True
|
||||
# We always set to True, regardless of the default value.
|
||||
# Users must pass False directly, but when passing nothing True is assumed.
|
||||
# i.e. the only way to disable somthing that defaults to True is to use the long form:
|
||||
# i.e. the only way to disable something that defaults to True is to use the long form:
|
||||
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
|
||||
# which then becomes True here.
|
||||
|
||||
|
@ -242,9 +242,6 @@ def add_argparse_args(
|
|||
if arg == 'track_grad_norm':
|
||||
use_type = float
|
||||
|
||||
if arg_default is inspect._empty:
|
||||
arg_default = None
|
||||
|
||||
parser.add_argument(
|
||||
f'--{arg}',
|
||||
dest=arg,
|
||||
|
@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]:
|
|||
|
||||
|
||||
def _gpus_arg_default(x) -> Union[int, str]:
|
||||
if ',' in x:
|
||||
return str(x)
|
||||
else:
|
||||
return int(x)
|
||||
return _gpus_allowed_type(x)
|
||||
|
||||
|
||||
def _int_or_float_type(x) -> Union[int, float]:
|
||||
|
|
|
@ -1,17 +1,51 @@
|
|||
import io
|
||||
from argparse import ArgumentParser
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities.argparse import (
|
||||
add_argparse_args,
|
||||
from_argparse_args,
|
||||
get_abbrev_qualified_cls_name,
|
||||
parse_argparser,
|
||||
parse_args_from_docstring,
|
||||
_gpus_arg_default,
|
||||
_int_or_float_type
|
||||
)
|
||||
|
||||
|
||||
class ArgparseExample:
|
||||
def __init__(self, a: int = 0, b: str = '', c: bool = False):
|
||||
self.a = a
|
||||
self.b = b
|
||||
self.c = c
|
||||
|
||||
|
||||
def test_from_argparse_args():
|
||||
args = Namespace(a=1, b='test', c=True, d='not valid')
|
||||
my_instance = from_argparse_args(ArgparseExample, args)
|
||||
assert my_instance.a == 1
|
||||
assert my_instance.b == 'test'
|
||||
assert my_instance.c
|
||||
|
||||
parser = ArgumentParser()
|
||||
mock_trainer = MagicMock()
|
||||
_ = from_argparse_args(mock_trainer, parser)
|
||||
mock_trainer.parse_argparser.assert_called_once_with(parser)
|
||||
|
||||
|
||||
def test_parse_argparser():
|
||||
args = Namespace(a=1, b='test', c=None, d='not valid')
|
||||
new_args = parse_argparser(ArgparseExample, args)
|
||||
assert new_args.a == 1
|
||||
assert new_args.b == 'test'
|
||||
assert new_args.c
|
||||
assert new_args.d == 'not valid'
|
||||
|
||||
|
||||
def test_parse_args_from_docstring_normal():
|
||||
args_help = parse_args_from_docstring(
|
||||
"""Constrain image dataset
|
||||
|
@ -168,3 +202,13 @@ def test_add_argparse_args_no_argument_group():
|
|||
args = parser.parse_args(fake_argv)
|
||||
assert args.main_arg == "abc"
|
||||
assert args.my_parameter == 2
|
||||
|
||||
|
||||
def test_gpus_arg_default():
|
||||
assert _gpus_arg_default('1,2') == '1,2'
|
||||
assert _gpus_arg_default('1') == 1
|
||||
|
||||
|
||||
def test_int_or_float_type():
|
||||
assert isinstance(_int_or_float_type('0.0'), float)
|
||||
assert isinstance(_int_or_float_type('0'), int)
|
Loading…
Reference in New Issue