239 lines
7.2 KiB
Python
239 lines
7.2 KiB
Python
import io
|
|
from argparse import ArgumentParser, Namespace
|
|
from typing import Generic, List, TypeVar
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from pytorch_lightning import Trainer
|
|
from pytorch_lightning.utilities.argparse import (
|
|
_get_abbrev_qualified_cls_name,
|
|
_gpus_allowed_type,
|
|
_int_or_float_type,
|
|
_parse_args_from_docstring,
|
|
_precision_allowed_type,
|
|
add_argparse_args,
|
|
from_argparse_args,
|
|
parse_argparser,
|
|
)
|
|
|
|
|
|
class ArgparseExample:
    """Three-field target class for ``from_argparse_args`` and ``parse_argparser`` in these tests."""

    def __init__(self, a: int = 0, b: str = "", c: bool = False):
        # Keep the constructor arguments exactly as received so tests can inspect them.
        self.a, self.b, self.c = a, b, c
|
|
|
|
|
|
def test_from_argparse_args():
    """``from_argparse_args`` builds an instance from a Namespace (the extra ``d`` entry does not break
    construction) and, when handed an ``ArgumentParser``, calls ``cls.parse_argparser`` with it."""
    namespace = Namespace(a=1, b="test", c=True, d="not valid")
    instance = from_argparse_args(ArgparseExample, namespace)
    assert (instance.a, instance.b) == (1, "test")
    assert instance.c

    argparser = ArgumentParser()
    trainer_cls = MagicMock()
    from_argparse_args(trainer_cls, argparser)
    trainer_cls.parse_argparser.assert_called_once_with(argparser)
|
|
|
|
|
|
def test_parse_argparser():
    """``parse_argparser`` keeps every Namespace entry and turns the ``None`` value of boolean ``c``
    into a truthy value."""
    raw = Namespace(a=1, b="test", c=None, d="not valid")
    parsed = parse_argparser(ArgparseExample, raw)
    for attr, value in {"a": 1, "b": "test"}.items():
        assert getattr(parsed, attr) == value
    assert parsed.c
    assert parsed.d == "not valid"
|
|
|
|
|
|
def test_parse_args_from_docstring_normal():
    """``_parse_args_from_docstring`` returns the ``Args:`` entries in order, joins continuation lines of
    a description with a single space, and ignores the following ``Examples:`` section."""
    args_help = _parse_args_from_docstring(
        """Constrain image dataset

        Args:
            root: Root directory of dataset where ``MNIST/processed/training.pt``
                and ``MNIST/processed/test.pt`` exist.
            train: If ``True``, creates dataset from ``training.pt``,
                otherwise from ``test.pt``.
            normalize: mean and std deviation of the MNIST dataset.
            download: If true, downloads the dataset from the internet and
                puts it in root directory. If dataset is already downloaded, it is not
                downloaded again.
            num_samples: number of examples per selected class/digit
            digits: list selected MNIST digits/classes

        Examples:
            >>> dataset = TrialMNIST(download=True)
            >>> len(dataset)
            300
            >>> sorted(set([d.item() for d in dataset.targets]))
            [0, 1, 2]
            >>> torch.bincount(dataset.targets)
            tensor([100, 100, 100])
        """
    )

    expected_args = ["root", "train", "normalize", "download", "num_samples", "digits"]
    # A single list comparison checks both the names and their order, and pytest shows a clear diff on
    # failure (unlike a length check followed by ``all(zip(...))``).
    assert list(args_help) == expected_args
    assert (
        args_help["root"] == "Root directory of dataset where ``MNIST/processed/training.pt``"
        " and ``MNIST/processed/test.pt`` exist."
    )
    assert args_help["normalize"] == "mean and std deviation of the MNIST dataset."
|
|
|
|
|
|
def test_parse_args_from_docstring_empty():
    """A docstring whose ``Args:`` section has no entries yields no parsed arguments."""
    docstring = """Constrain image dataset

        Args:

        Returns:

        Examples:
        """
    parsed = _parse_args_from_docstring(docstring)
    assert not parsed
|
|
|
|
|
|
def test_get_abbrev_qualified_cls_name():
    """``Trainer`` is reported as ``pl.Trainer``; a class from this test module keeps its full
    qualified name, including the ``<locals>`` segment."""
    assert _get_abbrev_qualified_cls_name(Trainer) == "pl.Trainer"

    class NestedClass:
        pass

    # Precondition: this module must not live under ``pytorch_lightning.``, otherwise the check
    # below would not exercise the non-abbreviated path.
    assert not __name__.startswith("pytorch_lightning.")
    qualified = f"{__name__}.test_get_abbrev_qualified_cls_name.<locals>.NestedClass"
    assert _get_abbrev_qualified_cls_name(NestedClass) == qualified
|
|
|
|
|
|
class AddArgparseArgsExampleClass:
    """
    Args:
        my_parameter: A thing.
    """

    # The class docstring above is read at runtime as help text, so it must stay exactly as written.
    # No docstring is placed on ``__init__`` for the same reason.
    def __init__(self, my_parameter: int = 0):
        ...  # the parameter only needs to exist in the signature

    @staticmethod
    def get_deprecated_arg_names() -> List[str]:
        # NOTE(review): presumably queried by ``add_argparse_args`` to skip deprecated names; none here.
        return list()
|
|
|
|
|
|
class AddArgparseArgsExampleClassViaInit:
    # Documents ``my_parameter`` in the ``__init__`` docstring instead of a class docstring.
    def __init__(self, my_parameter: int = 0):
        """
        Args:
            my_parameter: A thing.
        """
|
|
|
|
|
|
class AddArgparseArgsExampleClassNoDoc:
    # Intentionally has no docstring (neither on the class nor on ``__init__``), so its argument gets
    # no docstring-derived help text.
    def __init__(self, my_parameter: int = 0):
        ...
|
|
|
|
|
|
class AddArgparseArgsExampleClassGeneric:
    # Constructor parameter annotated with a user-defined generic class.
    # NOTE(review): the name ``invalid_class`` suggests this exercises how ``add_argparse_args`` treats
    # such annotations; the test that uses it is not visible here — confirm.
    T = TypeVar("T")

    class SomeClass(Generic[T]):
        ...

    def __init__(self, invalid_class: SomeClass):
        ...
|
|
|
|
|
|
def extract_help_text(parser):
    """Capture the output of ``parser.print_help`` as a string.

    Args:
        parser: The parser whose help message is rendered (an ``ArgumentParser`` in these tests).

    Returns:
        The complete help text that ``print_help`` would have printed.
    """
    # The context manager closes the buffer; ``getvalue`` returns the whole contents regardless of
    # the stream position, so no ``seek(0)``/``read()`` round-trip is needed.
    with io.StringIO() as buffer:
        parser.print_help(file=buffer)
        return buffer.getvalue()
|
|
|
|
|
|
@pytest.mark.parametrize(
    ["cls", "name"],
    [
        [AddArgparseArgsExampleClass, "AddArgparseArgsExampleClass"],
        [AddArgparseArgsExampleClassViaInit, "AddArgparseArgsExampleClassViaInit"],
        [AddArgparseArgsExampleClassNoDoc, "AddArgparseArgsExampleClassNoDoc"],
    ],
)
def test_add_argparse_args(cls, name):
    """``add_argparse_args`` extends the given parser in place, puts the class's arguments in a group
    titled after the class next to an existing group, and the result parses both groups' flags."""
    root_parser = ArgumentParser()
    main_group = root_parser.add_argument_group("main")
    main_group.add_argument("--main_arg", type=str, default="")

    returned = add_argparse_args(cls, root_parser)
    # In the default (argument-group) workflow the very same parser object comes back.
    assert returned is root_parser

    help_text = extract_help_text(returned)
    for fragment in ("main:", "--main_arg", f"{name}:", "--my_parameter"):
        assert fragment in help_text
    has_docstring = cls is not AddArgparseArgsExampleClassNoDoc
    if has_docstring:
        assert "A thing" in help_text

    parsed = returned.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert parsed.main_arg == "abc"
    assert parsed.my_parameter == 2
|
|
|
|
|
|
def test_negative_add_argparse_args():
    """``add_argparse_args`` raises ``RuntimeError`` when given an argument group instead of an
    ``ArgumentParser``."""
    # Build the fixtures outside ``pytest.raises`` so only the call under test can satisfy it.
    parser = ArgumentParser()
    bad_group = parser.add_argument_group("bad workflow")
    with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."):
        add_argparse_args(AddArgparseArgsExampleClass, bad_group)
|
|
|
|
|
|
def test_add_argparse_args_no_argument_group():
    """Tests the old workflow, ``add_argparse_args(..., use_argument_group=False)``: the class's arguments
    are added without a group titled after the class, on a new parser object that still parses them."""
    base_parser = ArgumentParser()
    base_parser.add_argument("--main_arg", type=str, default="")

    extended = add_argparse_args(AddArgparseArgsExampleClass, base_parser, use_argument_group=False)
    # Without argument groups a different parser object is returned.
    assert extended is not base_parser

    help_text = extract_help_text(extended)
    assert "--main_arg" in help_text
    assert "--my_parameter" in help_text
    assert "AddArgparseArgsExampleClass:" not in help_text

    parsed = extended.parse_args(["--main_arg=abc", "--my_parameter=2"])
    assert (parsed.main_arg, parsed.my_parameter) == ("abc", 2)
|
|
|
|
|
|
def test_gpus_allowed_type():
    """``"1,2"`` is returned unchanged as a string, while ``"1"`` becomes the integer 1."""
    expectations = {"1,2": "1,2", "1": 1}
    for raw, expected in expectations.items():
        assert _gpus_allowed_type(raw) == expected
|
|
|
|
|
|
def test_int_or_float_type():
    """``"0.0"`` converts to a ``float`` and ``"0"`` converts to an ``int``."""
    for text, expected_type in (("0.0", float), ("0", int)):
        assert isinstance(_int_or_float_type(text), expected_type)
|
|
|
|
|
|
@pytest.mark.parametrize(["arg", "expected"], [["--precision=16", 16], ["--precision=bf16", "bf16"]])
def test_precision_parsed_correctly(arg, expected):
    """The ``--precision`` flag added by ``Trainer.add_argparse_args`` parses ``16`` as an int and keeps
    ``bf16`` as a string."""
    trainer_parser = Trainer.add_argparse_args(ArgumentParser())
    namespace = trainer_parser.parse_args([arg])
    assert namespace.precision == expected
|
|
|
|
|
|
def test_precision_type():
    """``"16"`` is converted to the integer 16, while ``"bf16"`` is kept as a string."""
    assert _precision_allowed_type("16") == 16
    assert _precision_allowed_type("bf16") == "bf16"
|