lightning/tests/utilities/test_argparse.py

239 lines
7.2 KiB
Python
Raw Normal View History

import io
from argparse import ArgumentParser, Namespace
from typing import Generic, List, TypeVar
from unittest.mock import MagicMock
import pytest
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.argparse import (
2021-04-22 00:13:00 +00:00
_get_abbrev_qualified_cls_name,
_gpus_allowed_type,
_int_or_float_type,
2021-04-22 00:13:00 +00:00
_parse_args_from_docstring,
_precision_allowed_type,
add_argparse_args,
from_argparse_args,
parse_argparser,
)
class ArgparseExample:
    """Minimal target class whose ``__init__`` signature is fed to the argparse helpers under test."""

    def __init__(self, a: int = 0, b: str = "", c: bool = False):
        # Keep every constructor argument verbatim so the tests can inspect it afterwards.
        self.a, self.b, self.c = a, b, c
def test_from_argparse_args():
    """``from_argparse_args`` builds an instance from a ``Namespace`` and delegates ``ArgumentParser`` inputs."""
    # ``d`` is not an ``__init__`` parameter of ``ArgparseExample``; constructing the instance would raise a
    # ``TypeError`` if it were forwarded, so it must be filtered out.
    args = Namespace(a=1, b="test", c=True, d="not valid")
    my_instance = from_argparse_args(ArgparseExample, args)
    assert my_instance.a == 1
    assert my_instance.b == "test"
    # Check identity rather than truthiness so a non-bool truthy value cannot slip through.
    assert my_instance.c is True
    assert not hasattr(my_instance, "d")

    # When given a parser rather than a namespace, parsing is delegated to ``cls.parse_argparser``.
    parser = ArgumentParser()
    mock_trainer = MagicMock()
    from_argparse_args(mock_trainer, parser)
    mock_trainer.parse_argparser.assert_called_once_with(parser)
def test_parse_argparser():
    """``parse_argparser`` keeps plain values and turns a ``None`` value for a bool argument into ``True``."""
    args = Namespace(a=1, b="test", c=None, d="not valid")
    new_args = parse_argparser(ArgparseExample, args)
    assert new_args.a == 1
    assert new_args.b == "test"
    # ``c`` is annotated ``bool`` on ``ArgparseExample``. A ``None`` value (presumably a flag given without an
    # explicit value) must become exactly ``True``; a bare truthiness check would also accept e.g. ``"True"``.
    assert new_args.c is True
    # Keys that are not constructor parameters are passed through untouched.
    assert new_args.d == "not valid"
def test_parse_args_from_docstring_normal():
    """``_parse_args_from_docstring`` extracts every ``Args:`` entry in order and joins wrapped lines.

    The relative indentation inside the literal below is significant: entries sit one level below ``Args:``,
    continuation lines one level deeper, and ``Examples:`` returns to the ``Args:`` level.
    """
    args_help = _parse_args_from_docstring(
        """Constrain image dataset

        Args:
            root: Root directory of dataset where ``MNIST/processed/training.pt``
                and ``MNIST/processed/test.pt`` exist.
            train: If ``True``, creates dataset from ``training.pt``,
                otherwise from ``test.pt``.
            normalize: mean and std deviation of the MNIST dataset.
            download: If true, downloads the dataset from the internet and
                puts it in root directory. If dataset is already downloaded, it is not
                downloaded again.
            num_samples: number of examples per selected class/digit
            digits: list selected MNIST digits/classes

        Examples:
            >>> dataset = TrialMNIST(download=True)
            >>> len(dataset)
            300
            >>> sorted(set([d.item() for d in dataset.targets]))
            [0, 1, 2]
            >>> torch.bincount(dataset.targets)
            tensor([100, 100, 100])
        """
    )

    expected_args = ["root", "train", "normalize", "download", "num_samples", "digits"]
    # An ordered list comparison checks both the count and the order, and pytest shows a useful diff on failure.
    assert list(args_help) == expected_args

    # A single continuation line is appended to the entry, separated by one space.
    assert (
        args_help["root"] == "Root directory of dataset where ``MNIST/processed/training.pt``"
        " and ``MNIST/processed/test.pt`` exist."
    )
    assert args_help["normalize"] == "mean and std deviation of the MNIST dataset."
    # Several continuation lines are all joined onto the same entry.
    assert (
        args_help["download"] == "If true, downloads the dataset from the internet and"
        " puts it in root directory. If dataset is already downloaded, it is not"
        " downloaded again."
    )
def test_parse_args_from_docstring_empty():
    """An ``Args:`` section with no entries yields an empty mapping."""
    args_help = _parse_args_from_docstring(
        """Constrain image dataset

        Args:

        Returns:

        Examples:
        """
    )
    # Comparing against ``{}`` reports the unexpected entries on failure, unlike a length check.
    assert args_help == {}
def test_get_abbrev_qualified_cls_name():
    # Classes from the package are abbreviated with the ``pl.`` prefix.
    assert _get_abbrev_qualified_cls_name(Trainer) == "pl.Trainer"

    # This module must live outside the package, otherwise the name below would be abbreviated too.
    assert not __name__.startswith("pytorch_lightning.")

    class NestedClass:
        ...

    # Without abbreviation the full qualified name is kept, including the ``<locals>`` segment.
    full_name = f"{__name__}.test_get_abbrev_qualified_cls_name.<locals>.NestedClass"
    assert _get_abbrev_qualified_cls_name(NestedClass) == full_name
class AddArgparseArgsExampleClass:
    """
    Args:
        my_parameter: A thing.
    """

    # The class docstring above is the source of the "A thing" help text, so its layout must stay as-is.
    def __init__(self, my_parameter: int = 0):
        del my_parameter  # intentionally unused; only the signature is of interest

    @staticmethod
    def get_deprecated_arg_names() -> List[str]:
        # None of this example's arguments are deprecated.
        return list()
class AddArgparseArgsExampleClassViaInit:
    # Same parameter as ``AddArgparseArgsExampleClass``, but documented on ``__init__`` rather than on the class.
    def __init__(self, my_parameter: int = 0):
        """
        Args:
            my_parameter: A thing.
        """
        del my_parameter  # intentionally unused; only the signature and docstring are of interest
class AddArgparseArgsExampleClassNoDoc:
    # Deliberately has no docstring anywhere, so no help text can be harvested for ``my_parameter``.
    def __init__(self, my_parameter: int = 0):
        del my_parameter  # intentionally unused; only the signature is of interest
class AddArgparseArgsExampleClassGeneric:
    # Example whose only ``__init__`` parameter is annotated with a user-defined generic class.
    # NOTE(review): the name ``invalid_class`` suggests the argparse helpers are expected to reject or skip this
    # annotation; the test that uses this class is outside the visible chunk, so confirm there.
    T = TypeVar("T")

    class SomeClass(Generic[T]):
        ...

    def __init__(self, invalid_class: SomeClass):
        ...
def extract_help_text(parser):
    """Return the text that ``parser.print_help()`` would print.

    Args:
        parser: The ``ArgumentParser`` whose help message is rendered.

    Returns:
        The full help message as a string.
    """
    with io.StringIO() as buffer:
        parser.print_help(file=buffer)
        # ``getvalue`` returns the whole buffer regardless of the stream position, so no ``seek(0)`` is needed.
        return buffer.getvalue()
@pytest.mark.parametrize(
    ["cls", "name"],
    [
        [AddArgparseArgsExampleClass, "AddArgparseArgsExampleClass"],
        [AddArgparseArgsExampleClassViaInit, "AddArgparseArgsExampleClassViaInit"],
        [AddArgparseArgsExampleClassNoDoc, "AddArgparseArgsExampleClassNoDoc"],
    ],
)
def test_add_argparse_args(cls, name):
    """Tests that ``add_argparse_args`` handles argument groups correctly, and can be parsed."""
    parser = ArgumentParser()
    parser_main = parser.add_argument_group("main")
    parser_main.add_argument("--main_arg", type=str, default="")
    original_parser = parser
    parser = add_argparse_args(cls, parser)
    # In the default (argument-group) workflow the parser passed in is extended in place and returned.
    assert parser is original_parser

    # Both the user's own group and a group named after the class appear in the help output.
    help_text = extract_help_text(parser)
    assert "main:" in help_text
    assert "--main_arg" in help_text
    assert f"{name}:" in help_text
    assert "--my_parameter" in help_text
    if cls is AddArgparseArgsExampleClassNoDoc:
        # Without any docstring there is no help text to harvest, so none may be invented.
        assert "A thing" not in help_text
    else:
        assert "A thing" in help_text

    args = parser.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert args.main_arg == "abc"
    # The value is converted to an int rather than kept as the raw string "2".
    assert args.my_parameter == 2
def test_negative_add_argparse_args():
    """``add_argparse_args`` rejects an argument group passed in place of a full ``ArgumentParser``."""
    # Build the bad input outside ``pytest.raises`` so that only the call under test can satisfy it.
    group = ArgumentParser().add_argument_group("bad workflow")
    with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."):
        add_argparse_args(AddArgparseArgsExampleClass, group)
def test_add_argparse_args_no_argument_group():
    """Tests that ``add_argparse_args(..., use_argument_group=False)`` (old workflow) adds the class arguments
    without creating an argument group for them, and that the resulting parser can be parsed."""
    parser = ArgumentParser()
    parser.add_argument("--main_arg", type=str, default="")
    original_parser = parser
    parser = add_argparse_args(AddArgparseArgsExampleClass, parser, use_argument_group=False)
    # The old workflow returns a different parser object instead of extending the one passed in.
    assert parser is not original_parser

    # The original argument is still available on the returned parser, and no class-named group header appears.
    help_text = extract_help_text(parser)
    assert "--main_arg" in help_text
    assert "--my_parameter" in help_text
    assert "AddArgparseArgsExampleClass:" not in help_text

    args = parser.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert args.main_arg == "abc"
    assert args.my_parameter == 2
def test_gpus_allowed_type():
    # A comma-separated device list stays a string; a single index is converted to an int.
    for raw, parsed in (("1,2", "1,2"), ("1", 1)):
        assert _gpus_allowed_type(raw) == parsed
def test_int_or_float_type():
    # A literal with a decimal point parses as float; a bare integer literal parses as int.
    expected_types = {"0.0": float, "0": int}
    for text, expected_type in expected_types.items():
        assert isinstance(_int_or_float_type(text), expected_type)
@pytest.mark.parametrize(["arg", "expected"], [["--precision=16", 16], ["--precision=bf16", "bf16"]])
def test_precision_parsed_correctly(arg, expected):
    """``--precision`` parses to an int for numeric values and stays a string for ``bf16``."""
    trainer_parser = Trainer.add_argparse_args(ArgumentParser())
    parsed = trainer_parser.parse_args([arg])
    assert parsed.precision == expected
def test_precision_type():
    # Named precisions pass through unchanged; numeric ones are converted to int.
    for raw, parsed in (("bf16", "bf16"), ("16", 16)):
        assert _precision_allowed_type(raw) == parsed