Clean utilities/argparse and add missing tests (#6607)

This commit is contained in:
Ethan Harris 2021-03-22 08:53:51 +00:00 committed by GitHub
parent 870247ffe6
commit 853523ee64
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 9 deletions

View File

@ -67,7 +67,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
# Value has been passed as a flag => It is currently None, so we need to set it to True # Value has been passed as a flag => It is currently None, so we need to set it to True
# We always set to True, regardless of the default value. # We always set to True, regardless of the default value.
# Users must pass False directly, but when passing nothing True is assumed. # Users must pass False directly, but when passing nothing True is assumed.
# i.e. the only way to disable somthing that defaults to True is to use the long form: # i.e. the only way to disable something that defaults to True is to use the long form:
# "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None, # "--a_default_true_arg False" becomes False, while "--a_default_false_arg" becomes None,
# which then becomes True here. # which then becomes True here.
@ -242,9 +242,6 @@ def add_argparse_args(
if arg == 'track_grad_norm': if arg == 'track_grad_norm':
use_type = float use_type = float
if arg_default is inspect._empty:
arg_default = None
parser.add_argument( parser.add_argument(
f'--{arg}', f'--{arg}',
dest=arg, dest=arg,
@ -291,10 +288,7 @@ def _gpus_allowed_type(x) -> Union[int, str]:
def _gpus_arg_default(x) -> Union[int, str]: def _gpus_arg_default(x) -> Union[int, str]:
if ',' in x: return _gpus_allowed_type(x)
return str(x)
else:
return int(x)
def _int_or_float_type(x) -> Union[int, float]: def _int_or_float_type(x) -> Union[int, float]:

View File

@ -1,17 +1,51 @@
import io import io
from argparse import ArgumentParser from argparse import ArgumentParser, Namespace
from typing import List from typing import List
from unittest.mock import MagicMock
import pytest import pytest
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.utilities.argparse import ( from pytorch_lightning.utilities.argparse import (
add_argparse_args, add_argparse_args,
from_argparse_args,
get_abbrev_qualified_cls_name, get_abbrev_qualified_cls_name,
parse_argparser,
parse_args_from_docstring, parse_args_from_docstring,
_gpus_arg_default,
_int_or_float_type
) )
class ArgparseExample:
def __init__(self, a: int = 0, b: str = '', c: bool = False):
self.a = a
self.b = b
self.c = c
def test_from_argparse_args():
args = Namespace(a=1, b='test', c=True, d='not valid')
my_instance = from_argparse_args(ArgparseExample, args)
assert my_instance.a == 1
assert my_instance.b == 'test'
assert my_instance.c
parser = ArgumentParser()
mock_trainer = MagicMock()
_ = from_argparse_args(mock_trainer, parser)
mock_trainer.parse_argparser.assert_called_once_with(parser)
def test_parse_argparser():
args = Namespace(a=1, b='test', c=None, d='not valid')
new_args = parse_argparser(ArgparseExample, args)
assert new_args.a == 1
assert new_args.b == 'test'
assert new_args.c
assert new_args.d == 'not valid'
def test_parse_args_from_docstring_normal(): def test_parse_args_from_docstring_normal():
args_help = parse_args_from_docstring( args_help = parse_args_from_docstring(
"""Constrain image dataset """Constrain image dataset
@ -168,3 +202,13 @@ def test_add_argparse_args_no_argument_group():
args = parser.parse_args(fake_argv) args = parser.parse_args(fake_argv)
assert args.main_arg == "abc" assert args.main_arg == "abc"
assert args.my_parameter == 2 assert args.my_parameter == 2
def test_gpus_arg_default():
assert _gpus_arg_default('1,2') == '1,2'
assert _gpus_arg_default('1') == 1
def test_int_or_float_type():
assert isinstance(_int_or_float_type('0.0'), float)
assert isinstance(_int_or_float_type('0'), int)