Clean utilities/argparse and add missing tests (#6607)

This commit is contained in:
Ethan Harris 2021-03-22 08:53:51 +00:00 committed by GitHub
parent 870247ffe6
commit 853523ee64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 9 deletions

View File

@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
# Value has been passed as a flag => It is currently None, so we need to set it to True
# We always set to True, regardless of the default value.
# Users must pass False directly, but when passing nothing True is assumed.
# i.e. the only way to disable somthing that defaults to True is to use the long form:
# i.e. the only way to disable something that defaults to True is to use the long form:
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
# which then becomes True here.
@ -242,9 +242,6 @@ def add_argparse_args(
if arg == 'track_grad_norm':
use_type = float
if arg_default is inspect._empty:
arg_default = None
parser.add_argument(
f'--{arg}',
dest=arg,
@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]:
def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x:
return str(x)
else:
return int(x)
return _gpus_allowed_type(x)
def _int_or_float_type(x) -> Union[int, float]:

View File

@ -1,17 +1,51 @@
import io
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from typing import List
from unittest.mock import MagicMock
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.argparse import (
add_argparse_args,
from_argparse_args,
get_abbrev_qualified_cls_name,
parse_argparser,
parse_args_from_docstring,
_gpus_arg_default,
_int_or_float_type
)
class ArgparseExample:
def __init__(self, a: int = 0, b: str = '', c: bool = False):
self.a = a
self.b = b
self.c = c
def test_from_argparse_args():
args = Namespace(a=1, b='test', c=True, d='not valid')
my_instance = from_argparse_args(ArgparseExample, args)
assert my_instance.a == 1
assert my_instance.b == 'test'
assert my_instance.c
parser = ArgumentParser()
mock_trainer = MagicMock()
_ = from_argparse_args(mock_trainer, parser)
mock_trainer.parse_argparser.assert_called_once_with(parser)
def test_parse_argparser():
args = Namespace(a=1, b='test', c=None, d='not valid')
new_args = parse_argparser(ArgparseExample, args)
assert new_args.a == 1
assert new_args.b == 'test'
assert new_args.c
assert new_args.d == 'not valid'
def test_parse_args_from_docstring_normal():
args_help = parse_args_from_docstring(
"""Constrain image dataset
@ -168,3 +202,13 @@ def test_add_argparse_args_no_argument_group():
args = parser.parse_args(fake_argv)
assert args.main_arg == "abc"
assert args.my_parameter == 2
def test_gpus_arg_default():
assert _gpus_arg_default('1,2') == '1,2'
assert _gpus_arg_default('1') == 1
def test_int_or_float_type():
assert isinstance(_int_or_float_type('0.0'), float)
assert isinstance(_int_or_float_type('0'), int)