Removed from_argparse_args tests in test_cli.py (#14597)
This commit is contained in:
parent
3d540efe4f
commit
1680a76819
|
@ -14,9 +14,6 @@
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import sys
|
|
||||||
from argparse import Namespace
|
|
||||||
from contextlib import contextmanager, ExitStack, redirect_stdout
|
from contextlib import contextmanager, ExitStack, redirect_stdout
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Callable, List, Optional, Union
|
from typing import Callable, List, Optional, Union
|
||||||
|
@ -46,7 +43,6 @@ from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE
|
||||||
from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE
|
from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE
|
||||||
from pytorch_lightning.strategies import DDPStrategy
|
from pytorch_lightning.strategies import DDPStrategy
|
||||||
from pytorch_lightning.trainer.states import TrainerFn
|
from pytorch_lightning.trainer.states import TrainerFn
|
||||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||||
from tests_pytorch.helpers.runif import RunIf
|
from tests_pytorch.helpers.runif import RunIf
|
||||||
|
@ -67,42 +63,6 @@ def mock_subclasses(baseclass, *subclasses):
|
||||||
yield None
|
yield None
|
||||||
|
|
||||||
|
|
||||||
@mock.patch("argparse.ArgumentParser.parse_args")
|
|
||||||
def test_default_args(mock_argparse):
|
|
||||||
"""Tests default argument parser for Trainer."""
|
|
||||||
mock_argparse.return_value = Namespace(**Trainer.default_attributes())
|
|
||||||
|
|
||||||
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
|
|
||||||
args = parser.parse_args([])
|
|
||||||
|
|
||||||
args.max_epochs = 5
|
|
||||||
trainer = Trainer.from_argparse_args(args)
|
|
||||||
|
|
||||||
assert isinstance(trainer, Trainer)
|
|
||||||
assert trainer.max_epochs == 5
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
|
|
||||||
def test_add_argparse_args_redefined(cli_args):
|
|
||||||
"""Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
|
|
||||||
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
|
|
||||||
parser.add_lightning_class_args(Trainer, None)
|
|
||||||
|
|
||||||
args = parser.parse_args(cli_args)
|
|
||||||
|
|
||||||
# make sure we can pickle args
|
|
||||||
pickle.dumps(args)
|
|
||||||
|
|
||||||
# Check few deprecated args are not in namespace:
|
|
||||||
for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"):
|
|
||||||
assert depr_name not in args
|
|
||||||
|
|
||||||
trainer = Trainer.from_argparse_args(args=args)
|
|
||||||
pickle.dumps(trainer)
|
|
||||||
|
|
||||||
assert isinstance(trainer, Trainer)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
|
@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
|
||||||
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
|
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
|
||||||
"""Asserts error raised in case of passing not default cli arguments."""
|
"""Asserts error raised in case of passing not default cli arguments."""
|
||||||
|
@ -122,121 +82,6 @@ def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
|
||||||
parser.parse_args(cli_args)
|
parser.parse_args(cli_args)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
["cli_args", "expected"],
|
|
||||||
[
|
|
||||||
("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")),
|
|
||||||
(
|
|
||||||
"--auto_lr_find any_string --auto_scale_batch_size ON",
|
|
||||||
dict(auto_lr_find="any_string", auto_scale_batch_size=True),
|
|
||||||
),
|
|
||||||
("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)),
|
|
||||||
("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)),
|
|
||||||
("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)),
|
|
||||||
("--tpu_cores=8", dict(tpu_cores=8)),
|
|
||||||
("--tpu_cores=1,", dict(tpu_cores="1,")),
|
|
||||||
("--limit_train_batches=100", dict(limit_train_batches=100)),
|
|
||||||
("--limit_train_batches 0.8", dict(limit_train_batches=0.8)),
|
|
||||||
("--enable_model_summary FALSE", dict(enable_model_summary=False)),
|
|
||||||
(
|
|
||||||
"",
|
|
||||||
dict(
|
|
||||||
# These parameters are marked as Optional[...] in Trainer.__init__,
|
|
||||||
# with None as default. They should not be changed by the argparse
|
|
||||||
# interface.
|
|
||||||
min_steps=None,
|
|
||||||
accelerator=None,
|
|
||||||
profiler=None,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parse_args_parsing(cli_args, expected):
|
|
||||||
"""Test parsing simple types and None optionals not modified."""
|
|
||||||
cli_args = cli_args.split(" ") if cli_args else []
|
|
||||||
with mock.patch("sys.argv", ["any.py"] + cli_args):
|
|
||||||
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
|
|
||||||
parser.add_lightning_class_args(Trainer, None)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
for k, v in expected.items():
|
|
||||||
assert getattr(args, k) == v
|
|
||||||
if "tpu_cores" not in expected or _TPU_AVAILABLE:
|
|
||||||
assert Trainer.from_argparse_args(args)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
["cli_args", "expected", "instantiate"],
|
|
||||||
[
|
|
||||||
(["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False),
|
|
||||||
(["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False),
|
|
||||||
(['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
|
|
||||||
"""Test parsing complex types."""
|
|
||||||
with mock.patch("sys.argv", ["any.py"] + cli_args):
|
|
||||||
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
|
|
||||||
parser.add_lightning_class_args(Trainer, None)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
for k, v in expected.items():
|
|
||||||
assert getattr(args, k) == v
|
|
||||||
if instantiate:
|
|
||||||
assert Trainer.from_argparse_args(args)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
["cli_args", "expected_gpu"],
|
|
||||||
[
|
|
||||||
("--accelerator gpu --devices 1", [0]),
|
|
||||||
("--accelerator gpu --devices 0,", [0]),
|
|
||||||
("--accelerator gpu --devices 1,", [1]),
|
|
||||||
("--accelerator gpu --devices 0,1", [0, 1]),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
|
|
||||||
"""Test parsing of gpus and instantiation of Trainer."""
|
|
||||||
monkeypatch.setattr("lightning_lite.utilities.device_parser.num_cuda_devices", lambda: 2)
|
|
||||||
monkeypatch.setattr("lightning_lite.utilities.device_parser.is_cuda_available", lambda: True)
|
|
||||||
cli_args = cli_args.split(" ") if cli_args else []
|
|
||||||
with mock.patch("sys.argv", ["any.py"] + cli_args):
|
|
||||||
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
|
|
||||||
parser.add_lightning_class_args(Trainer, None)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
trainer = Trainer.from_argparse_args(args)
|
|
||||||
assert trainer.device_ids == expected_gpu
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
sys.version_info < (3, 7),
|
|
||||||
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec",
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
["cli_args", "extra_args"],
|
|
||||||
[
|
|
||||||
({}, {}),
|
|
||||||
(dict(logger=False), {}),
|
|
||||||
(dict(logger=False), dict(logger=True)),
|
|
||||||
(dict(logger=False), dict(enable_checkpointing=True)),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_init_from_argparse_args(cli_args, extra_args):
|
|
||||||
unknown_args = dict(unknown_arg=0)
|
|
||||||
|
|
||||||
# unknown args in the argparser/namespace should be ignored
|
|
||||||
with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init:
|
|
||||||
trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
|
|
||||||
expected = dict(cli_args)
|
|
||||||
expected.update(extra_args) # extra args should override any cli arg
|
|
||||||
init.assert_called_with(trainer, **expected)
|
|
||||||
|
|
||||||
# passing in unknown manual args should throw an error
|
|
||||||
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
|
|
||||||
Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(LightningModule):
|
class Model(LightningModule):
|
||||||
def __init__(self, model_param: int):
|
def __init__(self, model_param: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
Loading…
Reference in New Issue