# lightning/tests/utilities/test_cli.py (1369 lines, 52 KiB, Python)
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import pickle
import sys
from argparse import Namespace
from contextlib import redirect_stdout
from io import StringIO
from typing import List, Optional, Union
from unittest import mock
from unittest.mock import ANY
import pytest
import torch
import yaml
from packaging import version
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
CALLBACK_REGISTRY,
DATAMODULE_REGISTRY,
instantiate_class,
LightningArgumentParser,
LightningCLI,
LR_SCHEDULER_REGISTRY,
MODEL_REGISTRY,
OPTIMIZER_REGISTRY,
SaveConfigCallback,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.utils import no_warning_call
# Installed torchvision version; "0" stands in when torchvision is unavailable so that
# version-gated tests below can still compare against it.
torchvision_version = (
    version.parse(__import__("torchvision").__version__) if _TORCHVISION_AVAILABLE else version.parse("0")
)
@mock.patch("argparse.ArgumentParser.parse_args")
def test_default_args(mock_argparse, tmpdir):
    """``Trainer.from_argparse_args`` builds a Trainer from a parsed namespace and honours overridden values."""
    # the stdlib ``parse_args`` is patched to hand back the Trainer's default attributes
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())
    namespace = LightningArgumentParser(add_help=False, parse_as_dict=False).parse_args([])
    namespace.max_epochs = 5

    trainer = Trainer.from_argparse_args(namespace)
    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []])
def test_add_argparse_args_redefined(cli_args):
    """Overriding Trainer defaults from the CLI still yields a picklable namespace and a valid Trainer."""
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    parsed = parser.parse_args(cli_args)

    # the parsed namespace has to survive pickling
    pickle.dumps(parsed)

    # deprecated argument names must not leak into the namespace
    deprecated_names = ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs")
    assert not any(name in parsed for name in deprecated_names)

    trainer = Trainer.from_argparse_args(args=parsed)
    pickle.dumps(trainer)
    assert isinstance(trainer, Trainer)
@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
    """Arguments the Trainer parser does not accept make parsing fail instead of being ignored."""

    class _UnkArgError(Exception):
        pass

    def _fail_instead_of_exit(*_):
        # stands in for ``parser.exit`` so a parsing error surfaces as a catchable exception
        raise _UnkArgError

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    monkeypatch.setattr(parser, "exit", _fail_instead_of_exit, raising=True)

    with pytest.raises(_UnkArgError):
        parser.parse_args(cli_args)
@pytest.mark.parametrize(
    ["cli_args", "expected"],
    [
        ("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")),
        (
            "--auto_lr_find any_string --auto_scale_batch_size ON",
            dict(auto_lr_find="any_string", auto_scale_batch_size=True),
        ),
        ("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)),
        ("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)),
        ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)),
        ("--tpu_cores=8", dict(tpu_cores=8)),
        ("--tpu_cores=1,", dict(tpu_cores="1,")),
        ("--limit_train_batches=100", dict(limit_train_batches=100)),
        ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)),
        ("--enable_model_summary FALSE", dict(enable_model_summary=False)),
        (
            "",
            dict(
                # These parameters are marked as Optional[...] in Trainer.__init__,
                # with None as default. They should not be changed by the argparse
                # interface.
                min_steps=None,
                accelerator=None,
                weights_save_path=None,
                profiler=None,
            ),
        ),
    ],
)
def test_parse_args_parsing(cli_args, expected):
    """Simple-typed Trainer flags parse to the expected values and ``None`` optionals stay untouched."""
    argv = ["any.py"]
    if cli_args:
        argv += cli_args.split(" ")

    with mock.patch("sys.argv", argv):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    for name, value in expected.items():
        assert getattr(args, name) == value

    # a Trainer requesting TPU cores can only be built when TPUs are actually present
    if _TPU_AVAILABLE or "tpu_cores" not in expected:
        assert Trainer.from_argparse_args(args)
@pytest.mark.parametrize(
    ["cli_args", "expected", "instantiate"],
    [
        (["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False),
        (["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False),
        (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True),
    ],
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
    """List- and dict-typed Trainer flags parse to the expected Python structures."""
    with mock.patch("sys.argv", ["any.py", *cli_args]):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    for name, value in expected.items():
        assert getattr(args, name) == value

    if not instantiate:
        return
    assert Trainer.from_argparse_args(args)
@pytest.mark.parametrize(["cli_args", "expected_gpu"], [("--gpus 1", [0]), ("--gpus 0,", [0]), ("--gpus 0,1", [0, 1])])
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
    """``--gpus`` values are resolved to the expected device ids on the resulting Trainer."""
    # pretend two CUDA devices are visible so every parametrization is valid
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
    argv = ["any.py", *(cli_args.split(" ") if cli_args else [])]

    with mock.patch("sys.argv", argv):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    trainer = Trainer.from_argparse_args(args)
    assert trainer.data_parallel_device_ids == expected_gpu
@pytest.mark.skipif(
    sys.version_info < (3, 7),
    reason="signature inspection while mocking is not working in Python < 3.7 despite autospec",
)
@pytest.mark.parametrize(
    ["cli_args", "extra_args"],
    [
        ({}, {}),
        (dict(logger=False), {}),
        (dict(logger=False), dict(logger=True)),
        (dict(logger=False), dict(enable_checkpointing=True)),
    ],
)
def test_init_from_argparse_args(cli_args, extra_args):
    """``Trainer.from_argparse_args`` merges namespace values with explicit keyword arguments."""
    unknown_args = dict(unknown_arg=0)

    # unknown args coming from the argparser/namespace should be ignored
    with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init:
        namespace = Namespace(**cli_args, **unknown_args)
        trainer = Trainer.from_argparse_args(namespace, **extra_args)
        # explicitly passed keyword arguments take precedence over namespace values
        expected = {**cli_args, **extra_args}
        init.assert_called_with(trainer, **expected)

    # unknown keyword arguments passed explicitly, however, must be rejected
    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
        Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)
class Model(LightningModule):
    """Tiny LightningModule with a single integer hyperparameter configured through the CLI."""

    def __init__(self, model_param: int):
        super(Model, self).__init__()
        self.model_param = model_param
def _model_builder(model_param: int) -> Model:
    """Callable alternative to ``Model``, used to exercise function-based ``model_class`` support."""
    return Model(model_param=model_param)
def _trainer_builder(
    limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None
) -> Trainer:
    """Callable alternative to ``Trainer`` that exposes only a handful of its arguments."""
    trainer_kwargs = dict(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks)
    return Trainer(**trainer_kwargs)
@pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (_trainer_builder, _model_builder)])
def test_lightning_cli(trainer_class, model_class, monkeypatch):
    """LightningCLI instantiates model and trainer from CLI values and calls ``fit`` with them."""
    expected_model = dict(model_param=7)
    expected_trainer = dict(limit_train_batches=100)

    def _fake_fit(trainer, model):
        for attr, value in expected_model.items():
            assert getattr(model, attr) == value
        for attr, value in expected_trainer.items():
            assert getattr(trainer, attr) == value
        config_savers = [cb for cb in trainer.callbacks if isinstance(cb, SaveConfigCallback)]
        assert len(config_savers) == 1
        # the real fit loop is replaced, so trigger the (patched) hook by hand
        config_savers[0].on_train_start(trainer, model)

    def _fake_on_train_start(callback, trainer, _):
        dumped = callback.parser.dump(callback.config, skip_none=False)
        for key, value in {**expected_model, **expected_trainer}.items():
            assert f" {key}: {value}" in dumped
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, "fit", _fake_fit)
    monkeypatch.setattr(SaveConfigCallback, "on_train_start", _fake_on_train_start)

    argv = ["any.py", "fit", "--model.model_param=7", "--trainer.limit_train_batches=100"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)
    assert getattr(cli.trainer, "ran_asserts", False)
def test_lightning_cli_args_callbacks(tmpdir):
    """Callbacks given as ``class_path``/``init_args`` JSON on the command line are instantiated and attached."""
    callbacks = [
        dict(
            class_path="pytorch_lightning.callbacks.LearningRateMonitor",
            init_args=dict(logging_interval="epoch", log_momentum=True),
        ),
        dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")),
    ]

    class TestModel(BoringModel):
        def on_fit_start(self):
            def _of_type(cls):
                return [cb for cb in self.trainer.callbacks if isinstance(cb, cls)]

            lr_monitors = _of_type(LearningRateMonitor)
            assert len(lr_monitors) == 1
            assert lr_monitors[0].logging_interval == "epoch"
            assert lr_monitors[0].log_momentum is True

            checkpoints = _of_type(ModelCheckpoint)
            assert len(checkpoints) == 1
            assert checkpoints[0].monitor == "NAME"

            self.trainer.ran_asserts = True

    argv = ["any.py", "fit", f"--trainer.callbacks={json.dumps(callbacks)}"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    assert cli.trainer.ran_asserts
@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_configurable_callbacks(tmpdir, run):
    """A callback added with ``add_lightning_class_args`` is configurable and attached to the trainer."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")

        def fit(self, **_):
            """Skip training; only the instantiated callbacks are inspected."""

    argv = ["any.py"]
    if run:
        argv.append("fit")
    argv.extend([f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"])

    with mock.patch("sys.argv", argv):
        cli = MyLightningCLI(BoringModel, run=run)

    monitors = [cb for cb in cli.trainer.callbacks if isinstance(cb, LearningRateMonitor)]
    assert len(monitors) == 1
    assert monitors[0].logging_interval == "epoch"
def test_lightning_cli_args_cluster_environments(tmpdir):
    """A cluster environment passed through ``--trainer.plugins`` is the one the trainer ends up using."""
    plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")]

    class TestModel(BoringModel):
        def on_fit_start(self):
            # Ensure SLURMEnvironment is set, instead of default LightningEnvironment
            environment = self.trainer._accelerator_connector._cluster_environment
            assert isinstance(environment, SLURMEnvironment)
            self.trainer.ran_asserts = True

    argv = ["any.py", "fit", f"--trainer.plugins={json.dumps(plugins)}"]
    trainer_defaults = dict(default_root_dir=str(tmpdir), fast_dev_run=True)
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModel, trainer_defaults=trainer_defaults)

    assert cli.trainer.ran_asserts
def test_lightning_cli_args(tmpdir):
    """The config saved next to the logs matches the config the CLI actually parsed."""
    cli_args = [
        "fit",
        f"--data.data_dir={tmpdir}",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        "--trainer.enable_model_summary=False",
        "--seed_everything=1234",
    ]
    with mock.patch("sys.argv", ["any.py", *cli_args]):
        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]})

    saved_config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(saved_config_path)
    with open(saved_config_path) as f:
        saved = yaml.safe_load(f.read())["fit"]
    parsed = cli.config["fit"].as_dict()

    assert parsed["seed_everything"] == 1234
    # no model arguments to include, so neither config has a "model" entry
    assert "model" not in saved and "model" not in parsed
    for section in ("data", "trainer"):
        assert saved[section] == parsed[section]
def test_lightning_cli_save_config_cases(tmpdir):
    """The config is skipped for ``fast_dev_run``, saved for real runs and never silently overwritten."""
    config_path = tmpdir / "config.yaml"
    shared_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False"]
    dev_run_args = [*shared_args, "--trainer.fast_dev_run=1"]
    full_run_args = [*shared_args, "--trainer.max_epochs=1"]

    # With fast_dev_run!=False config should not be saved
    with mock.patch("sys.argv", ["any.py", *dev_run_args]):
        LightningCLI(BoringModel)
    assert not os.path.isfile(config_path)

    # With fast_dev_run==False config should be saved
    with mock.patch("sys.argv", ["any.py", *full_run_args]):
        LightningCLI(BoringModel)
    assert os.path.isfile(config_path)

    # If run again on same directory exception should be raised since config file already exists
    with mock.patch("sys.argv", ["any.py", *full_run_args]), pytest.raises(RuntimeError):
        LightningCLI(BoringModel)
def test_lightning_cli_config_and_subclass_mode(tmpdir):
    """A config file with ``class_path`` entries drives subclass mode, and the saved config mirrors the parsed one."""
    input_config = {
        "fit": {
            "model": {"class_path": "tests.helpers.BoringModel"},
            "data": {"class_path": "tests.helpers.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}},
            "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False},
        }
    }
    input_path = tmpdir / "config.yaml"
    with open(input_path, "w") as f:
        f.write(yaml.dump(input_config))

    with mock.patch("sys.argv", ["any.py", "--config", str(input_path)]):
        cli = LightningCLI(
            BoringModel,
            BoringDataModule,
            subclass_mode_model=True,
            subclass_mode_data=True,
            trainer_defaults={"callbacks": LearningRateMonitor()},
        )

    saved_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(saved_path)
    with open(saved_path) as f:
        saved = yaml.safe_load(f.read())["fit"]
    parsed = cli.config["fit"].as_dict()

    for section in ("model", "data", "trainer"):
        assert saved[section] == parsed[section]
def any_model_any_data_cli():
    """Build a CLI that accepts any LightningModule and any LightningDataModule subclass."""
    LightningCLI(
        LightningModule,
        LightningDataModule,
        subclass_mode_model=True,
        subclass_mode_data=True,
    )
def test_lightning_cli_help():
    """``--help`` lists the generic options and every Trainer argument; ``--data.help`` documents a subclass."""

    def _captured_stdout(argv):
        # help actions print to stdout and then exit, so capture the text and absorb the SystemExit
        buffer = StringIO()
        with mock.patch("sys.argv", argv), redirect_stdout(buffer), pytest.raises(SystemExit):
            any_model_any_data_cli()
        return buffer.getvalue()

    out = _captured_stdout(["any.py", "fit", "--help"])
    for option in ("--print_config", "--config", "--seed_everything", "--model.help", "--data.help"):
        assert option in out
    trainer_params = [name for name in inspect.signature(Trainer.__init__).parameters if name != "self"]
    for param in trainer_params:
        assert f"--trainer.{param}" in out

    out = _captured_stdout(["any.py", "fit", "--data.help=tests.helpers.BoringDataModule"])
    assert "--data.init_args.data_dir" in out
def test_lightning_cli_print_config():
    """``--print_config`` dumps the resolved config, including the subcommand's ``ckpt_path``."""
    argv = [
        "any.py",
        "predict",
        "--seed_everything=1234",
        "--model=tests.helpers.BoringModel",
        "--data=tests.helpers.BoringDataModule",
        "--print_config",
    ]
    buffer = StringIO()
    with mock.patch("sys.argv", argv), redirect_stdout(buffer), pytest.raises(SystemExit):
        any_model_any_data_cli()

    printed = yaml.safe_load(buffer.getvalue())
    assert printed["seed_everything"] == 1234
    assert printed["model"]["class_path"] == "tests.helpers.BoringModel"
    assert printed["data"]["class_path"] == "tests.helpers.BoringDataModule"
    assert printed["ckpt_path"] is None
def test_lightning_cli_submodules(tmpdir):
    """Init arguments typed as ``LightningModule`` are instantiated from ``class_path`` entries in a config file."""

    class MainModule(BoringModel):
        def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2

    # Build the config as data and serialize it with ``yaml.dump`` so that the nesting under
    # ``model`` cannot be broken by the indentation of a multi-line string literal.
    config = {
        "model": {
            "main_param": 2,
            "submodule1": {"class_path": "tests.helpers.BoringModel"},
            "submodule2": {"class_path": "tests.helpers.BoringModel"},
        }
    }
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(yaml.dump(config))

    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={config_path}"]
    with mock.patch("sys.argv", ["any.py", *cli_args]):
        cli = LightningCLI(MainModule, run=False)

    assert cli.config["model"]["main_param"] == 2
    assert isinstance(cli.model.submodule1, BoringModel)
    assert isinstance(cli.model.submodule2, BoringModel)
@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required")
def test_lightning_cli_torch_modules(tmpdir):
    """``torch.nn.Module`` init arguments (single and list-typed) are instantiated from a config file."""

    class TestModule(BoringModel):
        def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
            super().__init__()
            self.activation = activation
            self.transform = transform

    # Build the config as data and serialize it with ``yaml.dump`` so that the nesting under
    # ``model`` cannot be broken by the indentation of a multi-line string literal.
    config = {
        "model": {
            "activation": {"class_path": "torch.nn.LeakyReLU", "init_args": {"negative_slope": 0.2}},
            "transform": [
                {"class_path": "torchvision.transforms.Resize", "init_args": {"size": 64}},
                {"class_path": "torchvision.transforms.CenterCrop", "init_args": {"size": 64}},
            ],
        }
    }
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(yaml.dump(config))

    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={config_path}"]
    with mock.patch("sys.argv", ["any.py", *cli_args]):
        cli = LightningCLI(TestModule, run=False)

    assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
    assert cli.model.activation.negative_slope == 0.2
    assert len(cli.model.transform) == 2
    assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
class BoringModelRequiredClasses(BoringModel):
    """BoringModel with a required ``num_classes`` argument, used as a target for linked arguments."""

    def __init__(self, num_classes: int, batch_size: int = 8):
        super(BoringModelRequiredClasses, self).__init__()
        self.num_classes, self.batch_size = num_classes, batch_size
class BoringDataModuleBatchSizeAndClasses(BoringDataModule):
    """BoringDataModule exposing ``batch_size`` and, once instantiated, ``num_classes``."""

    def __init__(self, batch_size: int = 8):
        super(BoringDataModuleBatchSizeAndClasses, self).__init__()
        self.batch_size = batch_size
        # only set after instantiation, which is why links from it use ``apply_on="instantiate"``
        self.num_classes = 5
def test_lightning_cli_link_arguments(tmpdir):
    """Linked arguments propagate at parse time and at instantiation time, also in subclass mode."""

    class _LinkedCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.batch_size")
            parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

    class _LinkedSubclassCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.init_args.batch_size")
            parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

    root_dir_arg = f"--trainer.default_root_dir={tmpdir}"

    with mock.patch("sys.argv", ["any.py", root_dir_arg, "--data.batch_size=12"]):
        cli = _LinkedCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)
    assert cli.model.batch_size == 12
    assert cli.model.num_classes == 5

    # subclass mode: the model is picked by class path and the datamodule keeps its default batch size
    subclass_argv = ["any.py", root_dir_arg, "--model=tests.utilities.test_cli.BoringModelRequiredClasses"]
    with mock.patch("sys.argv", subclass_argv):
        cli = _LinkedSubclassCLI(
            BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
        )
    assert cli.model.batch_size == 8
    assert cli.model.num_classes == 5
class EarlyExitTestModel(BoringModel):
    """Model that aborts as soon as ``fit`` starts."""

    def on_fit_start(self):
        error_message = "Error on fit start"
        raise MisconfigurationException(error_message)
@pytest.mark.parametrize("logger", (False, True))
@pytest.mark.parametrize(
    "trainer_kwargs",
    (
        dict(strategy="ddp_spawn"),
        dict(strategy="ddp"),
        pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
    ),
)
def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
    """The config file is written (into ``version_0`` when logging) even though fit aborts on start."""
    trainer_defaults = {
        "default_root_dir": str(tmpdir),
        "logger": logger,
        "max_steps": 1,
        "max_epochs": 1,
        **trainer_kwargs,
    }
    with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
        MisconfigurationException, match=r"Error on fit start"
    ):
        LightningCLI(EarlyExitTestModel, trainer_defaults=trainer_defaults)

    if not logger:
        assert os.path.isfile(tmpdir / "config.yaml")
        return

    log_dir = tmpdir / "lightning_logs"
    # no more version dirs should get created
    assert os.listdir(log_dir) == ["version_0"]
    assert os.path.isfile(log_dir / "version_0" / "config.yaml")
def test_cli_config_overwrite(tmpdir):
    """A second run refuses to overwrite the saved config unless ``save_config_overwrite=True``."""
    trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1}

    def _run_fit(**cli_kwargs):
        with mock.patch("sys.argv", ["any.py", "fit"]):
            LightningCLI(BoringModel, trainer_defaults=trainer_defaults, **cli_kwargs)

    _run_fit()
    with pytest.raises(RuntimeError, match="Aborting to avoid overwriting"):
        _run_fit()
    _run_fit(save_config_overwrite=True)
@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_optimizer(tmpdir, run):
    """An optimizer added with ``add_optimizer_args`` replaces the model's ``configure_optimizers``."""

    # the class name is part of the expected warning, so it has to stay ``MyLightningCLI``
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)

    match = (
        "BoringModel.configure_optimizers` will be overridden by "
        "`MyLightningCLI.add_configure_optimizers_method_to_model`"
    )
    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1"] if run else []
    with mock.patch("sys.argv", ["any.py", *cli_args]), pytest.warns(UserWarning, match=match):
        cli = MyLightningCLI(BoringModel, run=run)

    assert cli.model.configure_optimizers is not BoringModel.configure_optimizers

    if run:
        optimizers = cli.trainer.optimizers
        assert len(optimizers) == 1
        assert isinstance(optimizers[0], torch.optim.Adam)
        assert len(cli.trainer.lr_schedulers) == 0
    else:
        assert isinstance(cli.model.configure_optimizers(), torch.optim.Adam)
def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir):
    """Optimizer and LR scheduler arguments are combined into the model's ``configure_optimizers``."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

    argv = [
        "any.py",
        "fit",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.fast_dev_run=1",
        "--lr_scheduler.gamma=0.8",
    ]
    with mock.patch("sys.argv", argv):
        cli = MyLightningCLI(BoringModel)

    assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
    optimizers, schedulers = cli.trainer.optimizers, cli.trainer.lr_schedulers
    assert len(optimizers) == 1
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert len(schedulers) == 1
    scheduler = schedulers[0]["scheduler"]
    assert isinstance(scheduler, torch.optim.lr_scheduler.ExponentialLR)
    assert scheduler.gamma == 0.8
def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
    """Optimizer and scheduler are chosen among several classes via ``class_path``/``init_args`` JSON."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
            parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR))

    optimizer_arg = dict(class_path="torch.optim.Adam", init_args=dict(lr=0.01))
    lr_scheduler_arg = dict(class_path="torch.optim.lr_scheduler.StepLR", init_args=dict(step_size=50))
    argv = [
        "any.py",
        "fit",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        f"--optimizer={json.dumps(optimizer_arg)}",
        f"--lr_scheduler={json.dumps(lr_scheduler_arg)}",
    ]
    with mock.patch("sys.argv", argv):
        cli = MyLightningCLI(BoringModel)

    optimizers, schedulers = cli.trainer.optimizers, cli.trainer.lr_schedulers
    assert len(optimizers) == 1
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert len(schedulers) == 1
    scheduler = schedulers[0]["scheduler"]
    assert isinstance(scheduler, torch.optim.lr_scheduler.StepLR)
    assert scheduler.step_size == 50
@pytest.mark.parametrize("use_registries", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir):
    """Optimizer/scheduler configs linked into model init arguments can be instantiated by the model."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            optim1_classes = OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam
            scheduler_classes = (
                LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR
            )
            parser.add_optimizer_args(optim1_classes, nested_key="optim1", link_to="model.optim1")
            parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
            parser.add_lr_scheduler_args(scheduler_classes, link_to="model.scheduler")

    class TestModel(BoringModel):
        def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
            super().__init__()
            self.optim1 = instantiate_class(self.parameters(), optim1)
            self.optim2 = instantiate_class(self.parameters(), optim2)
            self.scheduler = instantiate_class(self.optim1, scheduler)

    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", "--lr_scheduler.gamma=0.2"]
    if use_registries:
        # short registry names instead of full class paths
        cli_args.extend(
            [
                "--optim1",
                "Adam",
                "--optim1.weight_decay",
                "0.001",
                "--optim2=SGD",
                "--optim2.lr=0.01",
                "--lr_scheduler=ExponentialLR",
            ]
        )
    else:
        cli_args.extend(["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"])

    with mock.patch("sys.argv", ["any.py", *cli_args]):
        cli = MyLightningCLI(TestModel)

    model = cli.model
    assert isinstance(model.optim1, torch.optim.Adam)
    assert isinstance(model.optim2, torch.optim.SGD)
    assert model.optim2.param_groups[0]["lr"] == 0.01
    assert isinstance(model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
    """A subcommand runs ``before_<fn>``, then ``<fn>``, then ``after_<fn>`` on the CLI object."""

    class TestCLI(LightningCLI):
        def __init__(self, *args, **kwargs):
            # set before the parent constructor, which already runs the requested subcommand
            self.called = []
            super().__init__(*args, **kwargs)

        def _record_call(self, name):
            self.called.append(name)

        def before_fit(self):
            self._record_call("before_fit")

        def fit(self, **_):
            self._record_call("fit")

        def after_fit(self):
            self._record_call("after_fit")

        def before_validate(self):
            self._record_call("before_validate")

        def validate(self, **_):
            self._record_call("validate")

        def after_validate(self):
            self._record_call("after_validate")

        def before_test(self):
            self._record_call("before_test")

        def test(self, **_):
            self._record_call("test")

        def after_test(self):
            self._record_call("after_test")

        def before_predict(self):
            self._record_call("before_predict")

        def predict(self, **_):
            self._record_call("predict")

        def after_predict(self):
            self._record_call("after_predict")

        def before_tune(self):
            self._record_call("before_tune")

        def tune(self, **_):
            self._record_call("tune")

        def after_tune(self):
            self._record_call("after_tune")

    with mock.patch("sys.argv", ["any.py", fn]):
        cli = TestCLI(BoringModel)
    assert cli.called == [f"before_{fn}", fn, f"after_{fn}"]
def test_lightning_cli_subcommands():
    """Every argument a subcommand excludes must still exist on the matching Trainer method."""
    exclusions = LightningCLI.subcommands()
    trainer = Trainer()
    for subcommand, excluded in exclusions.items():
        accepted = inspect.signature(getattr(trainer, subcommand)).parameters
        for name in excluded:
            # if this fails, it's because the parameter has been removed from the associated `Trainer` function
            # and the `LightningCLI` subcommand exclusion list needs to be updated
            assert name in accepted
def test_lightning_cli_custom_subcommand():
    """A custom Trainer method can be exposed as a subcommand, with help text taken from its docstring."""

    class TestTrainer(Trainer):
        # The docstring below is parsed at runtime to build the subcommand help, so it must keep the
        # Google-style layout: blank line after the summary and indented ``Args:`` entries.
        def foo(self, model: LightningModule, x: int, y: float = 1.0):
            """Sample extra function.

            Args:
                model: A model
                x: The x
                y: The y
            """

    class TestCLI(LightningCLI):
        @staticmethod
        def subcommands():
            subcommands = LightningCLI.subcommands()
            # ``model`` is excluded from the parsed arguments, so its help ("A model") must not show up
            subcommands["foo"] = {"model"}
            return subcommands

    def _help_output(argv):
        buffer = StringIO()
        with mock.patch("sys.argv", argv), redirect_stdout(buffer), pytest.raises(SystemExit):
            TestCLI(BoringModel, trainer_class=TestTrainer)
        return buffer.getvalue()

    out = _help_output(["any.py", "-h"])
    assert "Sample extra function." in out
    assert "{fit,validate,test,predict,tune,foo}" in out

    out = _help_output(["any.py", "foo", "-h"])
    assert "A model" not in out
    assert "Sample extra function:" in out
    assert "--x X" in out
    assert "The x (required, type: int)" in out
    assert "--y Y" in out
    assert "The y (type: float, default: 1.0)" in out
def test_lightning_cli_run():
    """``run=False`` only instantiates, while the ``fit`` subcommand also trains."""
    cases = [
        (["any.py"], dict(run=False), 0),
        (["any.py", "fit"], dict(trainer_defaults={"max_steps": 1, "max_epochs": 1}), 1),
    ]
    for argv, cli_kwargs, expected_steps in cases:
        with mock.patch("sys.argv", argv):
            cli = LightningCLI(BoringModel, **cli_kwargs)
        assert cli.trainer.global_step == expected_steps
        assert isinstance(cli.trainer, Trainer)
        assert isinstance(cli.model, LightningModule)
@OPTIMIZER_REGISTRY
class CustomAdam(torch.optim.Adam):
    """Adam subclass registered through the ``OPTIMIZER_REGISTRY`` decorator."""
@LR_SCHEDULER_REGISTRY
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
    """CosineAnnealingLR subclass registered through the ``LR_SCHEDULER_REGISTRY`` decorator."""
@CALLBACK_REGISTRY
class CustomCallback(Callback):
    """No-op callback registered through the ``CALLBACK_REGISTRY`` decorator."""
def test_registries(tmpdir):
    """Built-in and decorated classes are registered; re-registering requires ``override=True``."""
    expected_entries = [
        (OPTIMIZER_REGISTRY, ("SGD", "RMSprop", "CustomAdam")),
        (LR_SCHEDULER_REGISTRY, ("CosineAnnealingLR", "CosineAnnealingWarmRestarts", "CustomCosineAnnealingLR")),
        (CALLBACK_REGISTRY, ("EarlyStopping", "CustomCallback")),
    ]
    for registry, names in expected_entries:
        for name in names:
            assert name in registry.names

    with pytest.raises(MisconfigurationException, match="is already present in the registry"):
        OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)

    # test `_Registry.__call__` returns the class
    assert isinstance(CustomCallback(), CustomCallback)
@MODEL_REGISTRY
class TestModel(BoringModel):
    """Registered model with a required ``foo`` and an optional ``bar`` argument."""

    def __init__(self, foo, bar=5):
        super(TestModel, self).__init__()
        self.foo, self.bar = foo, bar


# ``BoringModel`` is registered explicitly (not via the decorator) so it can be selected by name
MODEL_REGISTRY(cls=BoringModel)
def test_lightning_cli_model_choices():
    """Models in ``MODEL_REGISTRY`` can be selected by class name on the command line."""
    fit_argv = ["any.py", "fit", "--model=BoringModel"]
    with mock.patch("sys.argv", fit_argv), mock.patch("pytorch_lightning.Trainer._fit_impl") as run:
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.model, BoringModel)
        run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY)

    with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]):
        cli = LightningCLI(run=False)
    model = cli.model
    assert isinstance(model, TestModel)
    assert model.foo == 123
    assert model.bar == 5
@DATAMODULE_REGISTRY
class MyDataModule(BoringDataModule):
    """Registered datamodule with a required ``foo`` and an optional ``bar`` argument."""

    def __init__(self, foo, bar=5):
        super(MyDataModule, self).__init__()
        self.foo, self.bar = foo, bar


# ``BoringDataModule`` is registered explicitly (not via the decorator) so it can be selected by name
DATAMODULE_REGISTRY(cls=BoringDataModule)
def test_lightning_cli_datamodule_choices():
    """Datamodules in ``DATAMODULE_REGISTRY`` are selectable by name, with and without a fixed model class.

    Also checks when the ``data`` group is added to the parser automatically: only while the registry
    is non-empty and no datamodule class was passed explicitly.
    """
    # with set model
    with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
        "pytorch_lightning.Trainer._fit_impl"
    ) as run:
        cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.datamodule, BoringDataModule)
        # the instantiated datamodule is forwarded as the 4th positional argument of ``_fit_impl``
        run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY)

    with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]):
        cli = LightningCLI(BoringModel, run=False)
        assert isinstance(cli.datamodule, MyDataModule)
        assert cli.datamodule.foo == 123
        assert cli.datamodule.bar == 5

    # with configurable model
    with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch(
        "pytorch_lightning.Trainer._fit_impl"
    ) as run:
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
        assert isinstance(cli.model, BoringModel)
        assert isinstance(cli.datamodule, BoringDataModule)
        run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY)

    with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]):
        cli = LightningCLI(run=False)
        assert isinstance(cli.model, BoringModel)
        assert isinstance(cli.datamodule, MyDataModule)

    assert len(DATAMODULE_REGISTRY)  # needs a value initially added
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
        # data was not passed but we are adding it automatically because there are datamodules registered
        assert "data" in cli.parser.groups
        assert not hasattr(cli.parser.groups["data"], "group_class")

    # ``mock.patch.dict(..., clear=True)`` empties the registry only inside this ``with`` block
    with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True):
        cli = LightningCLI(BoringModel, run=False)
        # no registered classes so not added automatically
        assert "data" not in cli.parser.groups
    # this check must stay outside the block above: the registry is restored once the patch exits
    assert len(DATAMODULE_REGISTRY)  # check state was not modified

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, BoringDataModule, run=False)
        # since we are passing the DataModule, that's what's added to the parser
        assert cli.parser.groups["data"].group_class is BoringDataModule
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
    """This test validates registries are used when simplified command line are being used.

    Short class names (``Adam``, ``StepLR``, ``LearningRateMonitor``, ``ModelCheckpoint``, ``BoringModel``) are
    resolved through the registries, and the dotted ``init_args`` flags must reach the created instances. When
    ``use_class_path_callbacks`` is set, two extra ``Callback`` entries are appended in the full ``class_path``
    JSON form and both must be instantiated.
    """
    cli_args = [
        "--optimizer",
        "Adam",
        "--optimizer.lr",
        "0.0001",
        "--trainer.callbacks=LearningRateMonitor",
        "--trainer.callbacks.logging_interval=epoch",
        "--trainer.callbacks.log_momentum=True",
        "--model=BoringModel",
        "--trainer.callbacks=ModelCheckpoint",
        "--trainer.callbacks.monitor=loss",
        "--lr_scheduler",
        "StepLR",
        "--lr_scheduler.step_size=50",
    ]

    extras = []
    if use_class_path_callbacks:
        callbacks = [
            {"class_path": "pytorch_lightning.callbacks.Callback"},
            {"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
        ]
        cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
        extras = [Callback, Callback]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(run=False)

    assert isinstance(cli.model, BoringModel)
    optimizers, lr_scheduler = cli.model.configure_optimizers()
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert optimizers[0].param_groups[0]["lr"] == 0.0001
    assert isinstance(lr_scheduler[0], torch.optim.lr_scheduler.StepLR)
    assert lr_scheduler[0].step_size == 50

    callback_types = [type(c) for c in cli.trainer.callbacks]
    expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
    # Compare multiplicities, not mere membership: `extras` lists `Callback` twice and both must be present.
    assert all(callback_types.count(t) >= expected.count(t) for t in set(expected))

    # The dotted init args given after each shorthand callback must have been applied to that callback.
    lr_monitor = next(c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor))
    assert lr_monitor.logging_interval == "epoch"
    assert lr_monitor.log_momentum is True
    checkpoint = next(c for c in cli.trainer.callbacks if isinstance(c, ModelCheckpoint))
    assert checkpoint.monitor == "loss"
def test_argv_transformation_noop():
    # With no shorthand callback flags present, the conversion must hand the argv back unchanged.
    argv_in = ["any.py", "--trainer.max_epochs=1"]
    transformed = LightningArgumentParser._convert_argv_issue_85(
        CALLBACK_REGISTRY.classes, "trainer.callbacks", argv_in
    )
    assert transformed == argv_in
def test_argv_transformation_single_callback():
    """A shorthand callback plus its dotted init arg collapse into a single ``--trainer.callbacks`` value."""
    base = ["any.py", "--trainer.max_epochs=1"]
    # `argv_in` rather than `input` to avoid shadowing the builtin.
    argv_in = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
    callbacks = [
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_loss"},
        }
    ]
    expected = base + ["--trainer.callbacks", str(callbacks)]
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", argv_in)
    assert argv == expected
def test_argv_transformation_multiple_callbacks():
    """Repeated shorthand callbacks each keep their own init args and are merged, in order, into one list."""
    base = ["any.py", "--trainer.max_epochs=1"]
    # `argv_in` rather than `input` to avoid shadowing the builtin.
    argv_in = base + [
        "--trainer.callbacks=ModelCheckpoint",
        "--trainer.callbacks.monitor=val_loss",
        "--trainer.callbacks=ModelCheckpoint",
        "--trainer.callbacks.monitor=val_acc",
    ]
    callbacks = [
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_loss"},
        },
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_acc"},
        },
    ]
    expected = base + ["--trainer.callbacks", str(callbacks)]
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", argv_in)
    assert argv == expected
def test_argv_transformation_multiple_callbacks_with_config():
    """Shorthand callbacks and an explicit ``class_path`` list for the same key are merged into one list."""
    base = ["any.py", "--trainer.max_epochs=1"]
    nested_key = "trainer.callbacks"
    # `argv_in` rather than `input` to avoid shadowing the builtin.
    argv_in = base + [
        f"--{nested_key}=ModelCheckpoint",
        f"--{nested_key}.monitor=val_loss",
        f"--{nested_key}=ModelCheckpoint",
        f"--{nested_key}.monitor=val_acc",
        f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
    ]
    callbacks = [
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_loss"},
        },
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_acc"},
        },
        {"class_path": "pytorch_lightning.callbacks.Callback"},
    ]
    # Build the expected flag from `nested_key` so the input and expectation cannot drift apart.
    expected = base + [f"--{nested_key}", str(callbacks)]
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, argv_in)
    assert argv == expected
@pytest.mark.parametrize(
    "args,expected,nested_key,registry",
    [
        (
            ["--optimizer", "Adadelta"],
            {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}},
            "optimizer",
            OPTIMIZER_REGISTRY,
        ),
        (
            ["--optimizer", "Adadelta", "--optimizer.lr", "10"],
            {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}},
            "optimizer",
            OPTIMIZER_REGISTRY,
        ),
        (
            ["--lr_scheduler", "OneCycleLR"],
            {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}},
            "lr_scheduler",
            LR_SCHEDULER_REGISTRY,
        ),
        (
            ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"],
            {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}},
            "lr_scheduler",
            LR_SCHEDULER_REGISTRY,
        ),
    ],
)
def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
    """Shorthand optimizer / scheduler flags are folded into a single ``--<nested_key>`` dict-string argument."""
    prefix = ["any.py", "--trainer.max_epochs=1"]
    converted = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, prefix + args)
    expected_argv = prefix + [f"--{nested_key}", str(expected)]
    assert converted == expected_argv
def test_optimizers_and_lr_schedulers_reload(tmpdir):
    """Optimizer and LR scheduler shorthands survive a ``--print_config`` round trip.

    The printed YAML is checked, written to ``tmpdir`` and passed back via ``--config``. The reloaded CLI must build
    the same optimizer and scheduler classes, not merely parse without error.
    """
    base = ["any.py", "--trainer.max_epochs=1"]
    # `cli_args` rather than `input` to avoid shadowing the builtin.
    cli_args = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--optimizer",
        "Adam",
        "--optimizer.lr=0.1",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", cli_args + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(BoringModel, run=False)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["optimizer"]["class_path"] == "torch.optim.adam.Adam"
    assert dict_config["optimizer"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"

    # reload config
    yaml_config_file = tmpdir / "config.yaml"
    yaml_config_file.write_text(yaml_config, "utf-8")
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        cli = LightningCLI(BoringModel, run=False)

    # The reloaded config must yield the same optimizer / scheduler. The optimizer lr is not compared here because
    # OneCycleLR rewrites it on construction.
    optimizers, lr_schedulers = cli.model.configure_optimizers()
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert isinstance(lr_schedulers[0], torch.optim.lr_scheduler.OneCycleLR)
    assert lr_schedulers[0].total_steps == 10
def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir):
    """Optimizer / scheduler groups added in ``add_arguments_to_parser`` survive a ``--print_config`` round trip.

    The groups are linked (``link_to``) into the model's ``__init__`` as plain dicts, which the model instantiates
    itself with ``instantiate_class``.
    """

    class TestLightningCLI(LightningCLI):
        def __init__(self, *args):
            super().__init__(*args, run=False)

        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
            parser.add_optimizer_args(
                (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
            )
            parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
            parser.add_argument("--something", type=str, nargs="+")

    class TestModel(BoringModel):
        def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
            super().__init__()
            self.opt1_config = opt1_config
            self.opt2_config = opt2_config
            self.sch_config = sch_config
            # The linked configs must be instantiable into the expected classes.
            opt1 = instantiate_class(self.parameters(), opt1_config)
            assert isinstance(opt1, torch.optim.Adam)
            opt2 = instantiate_class(self.parameters(), opt2_config)
            assert isinstance(opt2, torch.optim.ASGD)
            sch = instantiate_class(opt1, sch_config)
            assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR)

    base = ["any.py", "--trainer.max_epochs=1"]
    # `cli_args` rather than `input` to avoid shadowing the builtin.
    # `--opt2.lr` deliberately comes before `--opt2` to exercise out-of-order shorthand arguments.
    cli_args = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--opt1",
        "Adam",
        "--opt2.lr=0.1",
        "--opt2",
        "ASGD",
        "--lr_scheduler.anneal_strategy=linear",
        "--something",
        "a",
        "b",
        "c",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", cli_args + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestLightningCLI(TestModel)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["opt1"]["class_path"] == "torch.optim.adam.Adam"
    assert dict_config["opt2"]["class_path"] == "torch.optim.asgd.ASGD"
    assert dict_config["opt2"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear"
    assert dict_config["something"] == ["a", "b", "c"]

    # reload config
    yaml_config_file = tmpdir / "config.yaml"
    yaml_config_file.write_text(yaml_config, "utf-8")
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        cli = TestLightningCLI(TestModel)

    assert cli.model.opt1_config["class_path"] == "torch.optim.adam.Adam"
    assert cli.model.opt2_config["class_path"] == "torch.optim.asgd.ASGD"
    assert cli.model.opt2_config["init_args"]["lr"] == 0.1
    assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear"
@RunIf(min_python="3.7.3")  # bpo-17185: `autospec=True` and `inspect.signature` do not play well
def test_lightning_cli_config_with_subcommand():
    # The config itself is keyed by the subcommand ("test"), so none is given on the command line.
    config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    argv = ["any.py", f"--config={config}"]
    with mock.patch("sys.argv", argv), mock.patch("pytorch_lightning.Trainer.test", autospec=True) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
@RunIf(min_python="3.7.3")
def test_lightning_cli_config_before_subcommand():
    # One config holds sections for several subcommands; the subcommand chosen after it selects its section.
    config = {
        "validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"},
        "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"},
    }
    # subcommand -> (trainer limit attribute, expected `verbose`, expected `ckpt_path`)
    cases = {
        "test": ("limit_test_batches", True, "foobar"),
        "validate": ("limit_val_batches", False, "barfoo"),
    }
    for subcommand, (limit_attr, verbose, ckpt_path) in cases.items():
        with mock.patch("sys.argv", ["any.py", f"--config={config}", subcommand]), mock.patch(
            f"pytorch_lightning.Trainer.{subcommand}", autospec=True
        ) as entrypoint_mock:
            cli = LightningCLI(BoringModel)
        # with `autospec=True` the mock binds calls to the real signature, so `model` may be given by keyword
        entrypoint_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=verbose, ckpt_path=ckpt_path)
        assert getattr(cli.trainer, limit_attr) == 1
@RunIf(min_python="3.7.3")
def test_lightning_cli_config_before_subcommand_two_configs():
    # Two configs, each holding one subcommand's section, are both passed ahead of the subcommand.
    config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
    config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    config_flags = [f"--config={config1}", f"--config={config2}"]

    with mock.patch("sys.argv", ["any.py", *config_flags, "test"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)
    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    with mock.patch("sys.argv", ["any.py", *config_flags, "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel)
    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1
@RunIf(min_python="3.7.3")
def test_lightning_cli_config_after_subcommand():
    # A config passed after the subcommand applies directly to it, so it is not keyed by subcommand name.
    config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}
    argv = ["any.py", "test", f"--config={config}"]
    with mock.patch("sys.argv", argv), mock.patch("pytorch_lightning.Trainer.test", autospec=True) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
@RunIf(min_python="3.7.3")
def test_lightning_cli_config_before_and_after_subcommand():
    # Settings from the config before the subcommand and from the one after it are combined; for the
    # overlapping `verbose` key the later (after-subcommand) config's value is the one expected.
    config_before = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    config_after = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}
    argv = ["any.py", f"--config={config_before}", "test", f"--config={config_after}"]
    with mock.patch("sys.argv", argv), mock.patch("pytorch_lightning.Trainer.test", autospec=True) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
    trainer = cli.trainer
    assert trainer.limit_test_batches == 1
    assert trainer.fast_dev_run == 1
def test_lightning_cli_parse_kwargs_with_subcommands(tmpdir):
    # Each subcommand gets its own default config file through `parser_kwargs`; values must not leak across them.
    subcommand_configs = {
        "fit": {"trainer": {"limit_train_batches": 2}},
        "validate": {"trainer": {"limit_val_batches": 3}},
    }
    parser_kwargs = {}
    for subcommand, config in subcommand_configs.items():
        config_path = tmpdir / f"{subcommand}.yaml"
        config_path.write_text(str(config), "utf8")
        parser_kwargs[subcommand] = {"default_config_files": [str(config_path)]}

    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
        "pytorch_lightning.Trainer.fit", autospec=True
    ) as fit_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    fit_mock.assert_called()
    assert cli.trainer.limit_train_batches == 2
    assert cli.trainer.limit_val_batches == 1.0

    with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    validate_mock.assert_called()
    assert cli.trainer.limit_train_batches == 1.0
    assert cli.trainer.limit_val_batches == 3
def test_lightning_cli_reinstantiate_trainer():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.max_epochs == 1000

    class TestCallback(Callback):
        ...

    # A fresh trainer built from the parsed config may override arguments and add callbacks.
    new_trainer = cli.instantiate_trainer(max_epochs=123, callbacks=[TestCallback()])

    # The overrides take effect on the new trainer...
    assert new_trainer.max_epochs == 123
    original_types = {type(cb) for cb in cli.trainer.callbacks}
    assert {type(cb) for cb in new_trainer.callbacks} == original_types | {TestCallback}

    # ...while the stored config keeps its original value.
    assert cli.config_init["trainer"]["max_epochs"] is None
def test_cli_configure_optimizers_warning(tmpdir):
    # Overriding `configure_optimizers` should only warn when an optimizer was actually configured via the CLI.
    match = "configure_optimizers` will be overridden by `LightningCLI"

    argv_without_optimizer = ["any.py"]
    with mock.patch("sys.argv", argv_without_optimizer), no_warning_call(UserWarning, match=match):
        LightningCLI(BoringModel, run=False)

    argv_with_optimizer = ["any.py", "--optimizer=Adam"]
    with mock.patch("sys.argv", argv_with_optimizer), pytest.warns(UserWarning, match=match):
        LightningCLI(BoringModel, run=False)