lightning/tests/tests_pytorch/test_cli.py

1620 lines
62 KiB
Python

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import inspect
import json
import operator
import os
import sys
from contextlib import contextmanager, ExitStack, redirect_stdout
from io import StringIO
from pathlib import Path
from typing import Callable, List, Optional, Union
from unittest import mock
from unittest.mock import ANY
import pytest
import torch
import yaml
from lightning_utilities import compare_version
from lightning_utilities.test.warning import no_warning_call
from tensorboard.backend.event_processing import event_accumulator
from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.pytorch import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.cli import (
_JSONARGPARSE_SIGNATURES_AVAILABLE,
instantiate_class,
LightningArgumentParser,
LightningCLI,
LRSchedulerCallable,
LRSchedulerTypeTuple,
OptimizerCallable,
SaveConfigCallback,
)
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.trainer.states import TrainerFn
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
from tests_pytorch.helpers.runif import RunIf
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
from jsonargparse import lazy_instance, Namespace
else:
from argparse import Namespace
@contextmanager
def mock_subclasses(baseclass, *subclasses):
    """Temporarily make ``baseclass`` report exactly ``subclasses`` as its direct children.

    Every listed subclass is additionally patched to report no children of its own, so code that walks the
    class hierarchy through ``__subclasses__`` only discovers the requested classes while the context is active.
    """
    patchers = [mock.patch.object(baseclass, "__subclasses__", return_value=[*subclasses])]
    patchers.extend(mock.patch.object(child, "__subclasses__", return_value=[]) for child in subclasses)
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        yield None
@pytest.fixture()
def cleandir(tmp_path, monkeypatch):
    """Run the requesting test with a fresh temporary directory as its working directory."""
    monkeypatch.chdir(tmp_path)
@pytest.fixture(autouse=True)
def ensure_cleandir():
    """After every test, check that no ``*.yaml`` file was left behind in the current working directory."""
    yield
    # make sure tests don't leave configuration files
    leftover_configs = glob.glob("*.yaml")
    assert not leftover_configs
@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
    """Unknown or non-default CLI arguments make the parser call ``exit``, intercepted here as an error."""

    class _UnkArgError(Exception):
        pass

    def _exit_with_error(*_):
        raise _UnkArgError

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)
    # route the parser's exit path into a dedicated exception so the test can catch it
    monkeypatch.setattr(parser, "exit", _exit_with_error, raising=True)

    with pytest.raises(_UnkArgError):
        parser.parse_args(cli_args)
class Model(LightningModule):
    """Minimal LightningModule with a single integer init argument, used to test CLI instantiation."""

    def __init__(self, model_param: int):
        super().__init__()
        # exposed as an attribute so tests can check the parsed value
        self.model_param = model_param
def _model_builder(model_param: int) -> Model:
    """Factory that lets ``test_lightning_cli`` pass a callable instead of a model class."""
    built = Model(model_param)
    return built
def _trainer_builder(
    limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None
) -> Trainer:
    """Factory that lets ``test_lightning_cli`` pass a callable instead of the ``Trainer`` class."""
    trainer_kwargs = {
        "limit_train_batches": limit_train_batches,
        "fast_dev_run": fast_dev_run,
        "callbacks": callbacks,
    }
    return Trainer(**trainer_kwargs)
@pytest.mark.parametrize(("trainer_class", "model_class"), [(Trainer, Model), (_trainer_builder, _model_builder)])
def test_lightning_cli(trainer_class, model_class, monkeypatch):
    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""
    expected_model = {"model_param": 7}
    expected_trainer = {"limit_train_batches": 100}

    def fit(trainer, model):
        for target, expected in ((model, expected_model), (trainer, expected_trainer)):
            for attr, value in expected.items():
                assert getattr(target, attr) == value
        save_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, SaveConfigCallback)]
        assert len(save_callbacks) == 1
        # trigger the (patched) hook by hand since the real fit loop is replaced
        save_callbacks[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, _):
        config_dump = callback.parser.dump(callback.config, skip_none=False)
        for attr, value in {**expected_model, **expected_trainer}.items():
            assert f" {attr}: {value}" in config_dump
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, "fit", fit)
    monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start)

    argv = ["any.py", "fit", "--model.model_param=7", "--trainer.limit_train_batches=100"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)

    assert hasattr(cli.trainer, "ran_asserts")
    assert cli.trainer.ran_asserts
def test_lightning_cli_args_callbacks(cleandir):
    """Callbacks given as a JSON list via ``--trainer.callbacks`` are instantiated with their ``init_args``."""
    lr_monitor_config = {
        "class_path": "lightning.pytorch.callbacks.LearningRateMonitor",
        "init_args": {"logging_interval": "epoch", "log_momentum": True},
    }
    checkpoint_config = {"class_path": "lightning.pytorch.callbacks.ModelCheckpoint", "init_args": {"monitor": "NAME"}}
    callbacks = [lr_monitor_config, checkpoint_config]

    class TestModel(BoringModel):
        def on_fit_start(self):
            def of_type(cls):
                return [cb for cb in self.trainer.callbacks if isinstance(cb, cls)]

            monitors = of_type(LearningRateMonitor)
            assert len(monitors) == 1
            assert monitors[0].logging_interval == "epoch"
            assert monitors[0].log_momentum is True

            checkpoints = of_type(ModelCheckpoint)
            assert len(checkpoints) == 1
            assert checkpoints[0].monitor == "NAME"

            self.trainer.ran_asserts = True

    argv = ["any.py", "fit", f"--trainer.callbacks={json.dumps(callbacks)}"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModel, trainer_defaults={"fast_dev_run": True, "logger": CSVLogger(".")})

    assert cli.trainer.ran_asserts
def test_lightning_cli_single_arg_callback():
    """A single ``--trainer.callbacks`` value is parsed as one callback rather than being wrapped in a list."""
    with mock.patch("sys.argv", ["any.py", "--trainer.callbacks=DeviceStatsMonitor"]):
        cli = LightningCLI(BoringModel, run=False)

    assert cli.config.trainer.callbacks.class_path == "lightning.pytorch.callbacks.DeviceStatsMonitor"
    # `config_init.trainer` is the trainer argument group and can never be a list, so the check has to target
    # the instantiated `callbacks` value itself.
    assert not isinstance(cli.config_init.trainer.callbacks, list)
@pytest.mark.parametrize("run", [False, True])
def test_lightning_cli_configurable_callbacks(cleandir, run):
    """A callback registered with ``add_lightning_class_args`` is configurable and attached to the trainer."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")

        def fit(self, **_):
            pass

    subcommand = ["fit"] if run else []
    argv = ["any.py", *subcommand, "--learning_rate_monitor.logging_interval=epoch"]
    with mock.patch("sys.argv", argv):
        cli = MyLightningCLI(BoringModel, run=run)

    monitors = [cb for cb in cli.trainer.callbacks if isinstance(cb, LearningRateMonitor)]
    assert len(monitors) == 1
    assert monitors[0].logging_interval == "epoch"
def test_lightning_cli_args_cluster_environments(cleandir):
    """A cluster environment passed via ``--trainer.plugins`` is the one used by the accelerator connector."""
    slurm_plugin = {"class_path": "lightning.fabric.plugins.environments.SLURMEnvironment"}

    class TestModel(BoringModel):
        def on_fit_start(self):
            # Ensure SLURMEnvironment is set, instead of default LightningEnvironment
            environment = self.trainer._accelerator_connector.cluster_environment
            assert isinstance(environment, SLURMEnvironment)
            self.trainer.ran_asserts = True

    argv = ["any.py", "fit", "--trainer.plugins=" + json.dumps([slurm_plugin])]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModel, trainer_defaults={"fast_dev_run": True})

    assert cli.trainer.ran_asserts
class DataDirDataModule(BoringDataModule):
    """BoringDataModule variant with a required ``data_dir`` init argument, used to exercise CLI parsing."""

    def __init__(self, data_dir):
        super().__init__()
        # keep the parsed value instead of silently discarding it, so it can be inspected after instantiation
        self.data_dir = data_dir
def test_lightning_cli_args(cleandir):
    """Parsed CLI values appear in ``cli.config`` and are mirrored in the saved ``config.yaml``."""
    trainer_flags = [
        "--trainer.max_epochs=1",
        "--trainer.limit_train_batches=1",
        "--trainer.limit_val_batches=0",
        "--trainer.enable_model_summary=False",
        "--trainer.logger=False",
    ]
    argv = ["any.py", "fit", "--data.data_dir=.", *trainer_flags, "--seed_everything=1234"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(BoringModel, DataDirDataModule)

    saved_path = "config.yaml"
    assert os.path.isfile(saved_path)
    with open(saved_path) as handle:
        saved_config = yaml.safe_load(handle.read())

    parsed_config = cli.config["fit"].as_dict()
    assert parsed_config["seed_everything"] == 1234
    assert "model" not in saved_config
    assert "model" not in parsed_config
    for section in ("data", "trainer"):
        assert saved_config[section] == parsed_config[section]
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
def test_lightning_env_parse(cleandir):
    """With ``default_env=True`` options are listed with, and settable through, ``PL_FIT__*`` env variables."""
    help_buffer = StringIO()
    with mock.patch("sys.argv", ["", "fit", "--help"]), redirect_stdout(help_buffer), pytest.raises(SystemExit):
        LightningCLI(BoringModel, DataDirDataModule, parser_kwargs={"default_env": True})
    help_text = help_buffer.getvalue()

    expected_env_names = (
        "PL_FIT__CONFIG",
        "PL_FIT__SEED_EVERYTHING",
        "PL_FIT__TRAINER__LOGGER",
        "PL_FIT__DATA__DATA_DIR",
        "PL_FIT__CKPT_PATH",
    )
    for env_name in expected_env_names:
        assert env_name in help_text

    env_vars = {
        "PL_FIT__DATA__DATA_DIR": ".",
        "PL_FIT__TRAINER__DEFAULT_ROOT_DIR": ".",
        "PL_FIT__TRAINER__MAX_EPOCHS": "1",
        "PL_FIT__TRAINER__LOGGER": "False",
    }
    with mock.patch.dict(os.environ, env_vars), mock.patch("sys.argv", ["", "fit"]):
        cli = LightningCLI(BoringModel, DataDirDataModule, parser_kwargs={"default_env": True})

    fit_config = cli.config.fit
    assert fit_config.data.data_dir == "."
    assert fit_config.trainer.default_root_dir == "."
    assert fit_config.trainer.max_epochs == 1
    assert fit_config.trainer.logger is False
def test_lightning_cli_save_config_cases(cleandir):
    """The config is skipped for ``fast_dev_run``, saved for a real run, and never silently overwritten."""
    config_path = "config.yaml"

    def make_argv(*extra):
        return ["any.py", "fit", "--trainer.logger=false", *extra]

    # With fast_dev_run!=False config should not be saved
    with mock.patch("sys.argv", make_argv("--trainer.fast_dev_run=1")):
        LightningCLI(BoringModel)
    assert not os.path.isfile(config_path)

    # With fast_dev_run==False config should be saved
    with mock.patch("sys.argv", make_argv("--trainer.max_epochs=1")):
        LightningCLI(BoringModel)
    assert os.path.isfile(config_path)

    # If run again on same directory exception should be raised since config file already exists
    with mock.patch("sys.argv", make_argv("--trainer.max_epochs=1")), pytest.raises(RuntimeError):
        LightningCLI(BoringModel)
def test_lightning_cli_save_config_only_once(cleandir):
    """SaveConfigCallback writes the config on the first trainer call and does not fail on later ones."""
    with mock.patch("sys.argv", ["any.py", "--trainer.logger=false", "--trainer.max_epochs=1"]):
        cli = LightningCLI(BoringModel, run=False)

    callback = next(cb for cb in cli.trainer.callbacks if isinstance(cb, SaveConfigCallback))
    assert not callback.overwrite
    assert not callback.already_saved

    cli.trainer.fit(cli.model)
    assert os.path.isfile("config.yaml")
    assert callback.already_saved

    cli.trainer.test(cli.model)  # Should not fail because config already saved
def test_lightning_cli_save_config_seed_everything(cleandir):
    """``--seed_everything=true`` is written to the saved config as the concrete integer seed."""
    config_path = Path("config.yaml")

    def check_saved_seed(expected_seed):
        saved = yaml.safe_load(config_path.read_text())
        assert isinstance(saved["seed_everything"], int)
        assert saved["seed_everything"] == expected_seed

    fit_argv = ["any.py", "fit", "--seed_everything=true", "--trainer.logger=false", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", fit_argv):
        cli = LightningCLI(BoringModel)
    check_saved_seed(cli.config.fit.seed_everything)

    no_run_argv = ["any.py", "--seed_everything=true", "--trainer.logger=false"]
    with mock.patch("sys.argv", no_run_argv):
        cli = LightningCLI(BoringModel, run=False)
    check_saved_seed(cli.config.seed_everything)
def test_save_to_log_dir_false_error():
    """Constructing the base SaveConfigCallback with ``save_to_log_dir=False`` raises ValueError."""
    with pytest.raises(ValueError):
        SaveConfigCallback(LightningArgumentParser(), Namespace(), save_to_log_dir=False)
def test_lightning_cli_logger_save_config(cleandir):
    """A SaveConfigCallback subclass can log the config as a hyperparameter instead of writing config.yaml."""
    captured = {}

    class LoggerSaveConfigCallback(SaveConfigCallback):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, save_to_log_dir=False, **kwargs)

        def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
            captured["config"] = self.parser.dump(self.config)
            trainer.logger.log_hyperparams({"config": captured["config"]})

    argv = [
        "any.py",
        "fit",
        "--trainer.max_epochs=1",
        "--trainer.logger=TensorBoardLogger",
        f"--trainer.logger.save_dir={os.getcwd()}",
    ]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(BoringModel, save_config_callback=LoggerSaveConfigCallback)

    log_dir = cli.trainer.log_dir
    assert os.path.isdir(log_dir)
    assert not os.path.isfile(os.path.join(log_dir, "config.yaml"))

    event_files = glob.glob(os.path.join(log_dir, "events.out.tfevents.*"))
    assert len(event_files) == 1
    accumulator = event_accumulator.EventAccumulator(event_files[0])
    accumulator.Reload()
    raw_hparams = accumulator._plugin_to_tag_to_content["hparams"]["_hparams_/session_start_info"]
    logged_hparams = HParamsPluginData.FromString(raw_hparams).session_start_info.hparams
    assert logged_hparams.get("config") is not None
    assert logged_hparams["config"].string_value == captured.get("config")
def test_lightning_cli_config_and_subclass_mode(cleandir):
    """A ``--config`` file selects model and datamodule classes in subclass mode; the saved config matches it."""
    input_config = {
        "fit": {
            "model": {"class_path": "lightning.pytorch.demos.boring_classes.BoringModel"},
            "data": {
                "class_path": "DataDirDataModule",
                "init_args": {"data_dir": "."},
            },
            "trainer": {"max_epochs": 1, "enable_model_summary": False, "logger": False},
        }
    }
    config_path = "config.yaml"
    with open(config_path, "w") as f:
        f.write(yaml.dump(input_config))

    # `DataDirDataModule` is referenced by its short name, so expose it as the only datamodule subclass.
    with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses(
        LightningDataModule, DataDirDataModule
    ):
        cli = LightningCLI(
            BoringModel,
            BoringDataModule,
            subclass_mode_model=True,
            subclass_mode_data=True,
            # with logging disabled the config is saved as `config.yaml` in the working directory, i.e. the
            # input file itself, so overwriting has to be allowed
            save_config_kwargs={"overwrite": True},
        )

    assert os.path.isfile(config_path)
    with open(config_path) as f:
        loaded_config = yaml.safe_load(f.read())

    cli_config = cli.config["fit"].as_dict()
    for section in ("model", "data", "trainer"):
        assert loaded_config[section] == cli_config[section]
def any_model_any_data_cli():
    """Build a LightningCLI that accepts any LightningModule and any LightningDataModule subclass."""
    LightningCLI(
        LightningModule,
        LightningDataModule,
        subclass_mode_model=True,
        subclass_mode_data=True,
    )
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
@pytest.mark.skipif(
    (sys.version_info.major, sys.version_info.minor) == (3, 9)
    and compare_version("jsonargparse", operator.lt, "4.24.0"),
    reason="--trainer.precision is not parsed",
)
def test_lightning_cli_help():
    """``fit --help`` lists the core options and every Trainer argument; ``--data.help`` shows a subclass's args."""
    help_buffer = StringIO()
    with mock.patch("sys.argv", ["any.py", "fit", "--help"]), redirect_stdout(help_buffer), pytest.raises(
        SystemExit
    ):
        any_model_any_data_cli()
    help_text = help_buffer.getvalue()

    for option in ("--print_config", "--config", "--seed_everything", "--model.help", "--data.help"):
        assert option in help_text

    trainer_params = [name for name in inspect.signature(Trainer.__init__).parameters if name != "self"]
    for param in trainer_params:
        assert f"--trainer.{param}" in help_text

    help_buffer = StringIO()
    data_help_argv = ["any.py", "fit", "--data.help=DataDirDataModule"]
    with mock.patch("sys.argv", data_help_argv), redirect_stdout(help_buffer), mock_subclasses(
        LightningDataModule, DataDirDataModule
    ), pytest.raises(SystemExit):
        any_model_any_data_cli()
    assert "--data.init_args.data_dir" in help_buffer.getvalue()
def test_lightning_cli_print_config():
    """``--print_config`` prints the resolved config as YAML, prefixed by a version header, then exits."""
    argv = [
        "any.py",
        "predict",
        "--seed_everything=1234",
        "--model=lightning.pytorch.demos.boring_classes.BoringModel",
        "--data=lightning.pytorch.demos.boring_classes.BoringDataModule",
        "--print_config",
    ]
    buffer = StringIO()
    with mock.patch("sys.argv", argv), redirect_stdout(buffer), pytest.raises(SystemExit):
        any_model_any_data_cli()
    printed = buffer.getvalue()

    # test dump_header
    assert printed.startswith(f"# lightning.pytorch=={__version__}")

    printed_config = yaml.safe_load(printed)
    assert printed_config["seed_everything"] == 1234
    assert printed_config["model"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringModel"
    assert printed_config["data"]["class_path"] == "lightning.pytorch.demos.boring_classes.BoringDataModule"
    assert printed_config["ckpt_path"] is None
def test_lightning_cli_submodules(cleandir):
    """Init arguments typed as LightningModule are instantiated from ``class_path`` entries in a config file."""

    class MainModule(BoringModel):
        def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2

    config = """model:
        main_param: 2
        submodule1:
            class_path: lightning.pytorch.demos.boring_classes.BoringModel
        submodule2:
            class_path: lightning.pytorch.demos.boring_classes.BoringModel
    """
    config_file = Path("config.yaml")
    config_file.write_text(config)

    with mock.patch("sys.argv", ["any.py", f"--config={config_file}"]):
        cli = LightningCLI(MainModule, run=False)

    assert cli.config["model"]["main_param"] == 2
    for submodule in (cli.model.submodule1, cli.model.submodule2):
        assert isinstance(submodule, BoringModel)
@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
def test_lightning_cli_torch_modules(cleandir):
    """``torch.nn.Module`` init arguments, single or in a list, are instantiated from ``class_path`` entries."""

    class TestModule(BoringModel):
        # `activation` defaults to None, so its annotation must be Optional to describe the real signature.
        def __init__(
            self,
            activation: Optional[torch.nn.Module] = None,
            transform: Optional[List[torch.nn.Module]] = None,
        ):
            super().__init__()
            self.activation = activation
            self.transform = transform

    config = """model:
        activation:
          class_path: torch.nn.LeakyReLU
          init_args:
            negative_slope: 0.2
        transform:
          - class_path: torchvision.transforms.Resize
            init_args:
              size: 64
          - class_path: torchvision.transforms.CenterCrop
            init_args:
              size: 64
    """
    config_path = Path("config.yaml")
    config_path.write_text(config)
    cli_args = [f"--config={config_path}"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(TestModule, run=False)

    assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
    assert cli.model.activation.negative_slope == 0.2
    assert len(cli.model.transform) == 2
    assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
class BoringModelRequiredClasses(BoringModel):
    """BoringModel with a required ``num_classes`` and an optional ``batch_size`` init argument."""

    def __init__(self, num_classes: int, batch_size: int = 8):
        super().__init__()
        self.num_classes, self.batch_size = num_classes, batch_size
class BoringDataModuleBatchSizeAndClasses(BoringDataModule):
    """BoringDataModule with a ``batch_size`` init argument and a ``num_classes`` attribute set on creation."""

    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        # only available after instantiation, not as an init argument
        self.num_classes = 5
def test_lightning_cli_link_arguments():
    """Parse-time and instantiate-time links from ``data`` into ``model`` work in plain and subclass mode."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.batch_size")
            parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

    with mock.patch("sys.argv", ["any.py", "--data.batch_size=12"]):
        cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)

    assert cli.model.batch_size == 12
    assert cli.model.num_classes == 5

    # in subclass mode the link targets live under `init_args`
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.init_args.batch_size")
            parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

    with mock.patch("sys.argv", ["any.py", "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"]):
        cli = MyLightningCLI(
            BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
        )

    assert cli.model.batch_size == 8
    assert cli.model.num_classes == 5
class EarlyExitTestModel(BoringModel):
    """Model whose ``on_fit_start`` hook fails immediately, aborting the run right after setup."""

    def on_fit_start(self):
        message = "Error on fit start"
        raise MisconfigurationException(message)
# mps not yet supported by distributed
@RunIf(skip_windows=True, mps=False)
@pytest.mark.parametrize("logger", [False, TensorBoardLogger(".")])
@pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp"])
def test_cli_distributed_save_config_callback(cleandir, logger, strategy):
    """Under a DDP strategy the config is still written when fit fails right after setup, in a single version dir."""
    from torch.multiprocessing import ProcessRaisedException

    trainer_defaults = {
        "logger": logger,
        "max_steps": 1,
        "max_epochs": 1,
        "strategy": strategy,
        "accelerator": "auto",
        "devices": 1,
    }
    expected_errors = (MisconfigurationException, ProcessRaisedException)
    with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(expected_errors, match=r"Error on fit start"):
        LightningCLI(EarlyExitTestModel, trainer_defaults=trainer_defaults)

    if not logger:
        config_path = "config.yaml"
    else:
        config_dir = Path("lightning_logs")
        # no more version dirs should get created
        assert os.listdir(config_dir) == ["version_0"]
        config_path = config_dir / "version_0" / "config.yaml"
    assert os.path.isfile(config_path)
def test_cli_config_overwrite(cleandir):
    """A second run refuses to overwrite an existing config unless ``overwrite=True`` is requested."""
    trainer_defaults = {"max_steps": 1, "max_epochs": 1, "logger": False}

    def run_cli(**cli_kwargs):
        with mock.patch("sys.argv", ["any.py", "fit"]):
            LightningCLI(BoringModel, trainer_defaults=trainer_defaults, **cli_kwargs)

    run_cli()
    with pytest.raises(RuntimeError, match="Aborting to avoid overwriting"):
        run_cli()
    run_cli(save_config_kwargs={"overwrite": True})
def test_cli_config_filename(tmp_path):
    """``save_config_kwargs={"config_filename": ...}`` sets the file name of the saved config."""
    # `tmp_path` (pathlib) replaces the legacy py.path-based `tmpdir` fixture
    with mock.patch("sys.argv", ["any.py", "fit"]):
        LightningCLI(
            BoringModel,
            trainer_defaults={"default_root_dir": str(tmp_path), "logger": False, "max_steps": 1, "max_epochs": 1},
            save_config_kwargs={"config_filename": "name.yaml"},
        )
    assert (tmp_path / "name.yaml").is_file()
@pytest.mark.parametrize("run", [False, True])
def test_lightning_cli_optimizer(run):
    """``add_optimizer_args`` replaces the model's ``configure_optimizers`` and warns about the override."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)

    match = "BoringModel.configure_optimizers` will be overridden by `MyLightningCLI.configure_optimizers`"
    run_args = ["fit", "--trainer.fast_dev_run=1"] if run else []
    with mock.patch("sys.argv", ["any.py", *run_args]), pytest.warns(UserWarning, match=match):
        cli = MyLightningCLI(BoringModel, run=run)

    assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
    if run:
        assert len(cli.trainer.optimizers) == 1
        assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
        assert len(cli.trainer.lr_scheduler_configs) == 0
    else:
        optimizer = cli.model.configure_optimizers()
        assert isinstance(optimizer, torch.optim.Adam)
def test_lightning_cli_optimizer_and_lr_scheduler():
    """Registered optimizer and LR scheduler arguments are used by the generated ``configure_optimizers``."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

    with mock.patch("sys.argv", ["any.py", "fit", "--trainer.fast_dev_run=1", "--lr_scheduler.gamma=0.8"]):
        cli = MyLightningCLI(BoringModel)

    trainer = cli.trainer
    assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
    assert len(trainer.optimizers) == 1
    assert isinstance(trainer.optimizers[0], torch.optim.Adam)
    assert len(trainer.lr_scheduler_configs) == 1
    scheduler = trainer.lr_scheduler_configs[0].scheduler
    assert isinstance(scheduler, torch.optim.lr_scheduler.ExponentialLR)
    assert scheduler.gamma == 0.8
def test_cli_no_need_configure_optimizers(cleandir):
    """A model without ``configure_optimizers`` can still be fit when ``--optimizer`` is given on the CLI."""

    class BoringModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, *_):
            ...

        def train_dataloader(self):
            ...

        # did not define `configure_optimizers`

    from lightning.pytorch.trainer.configuration_validator import __verify_train_val_loop_configuration

    validator_target = "lightning.pytorch.trainer.configuration_validator.__verify_train_val_loop_configuration"
    with mock.patch("sys.argv", ["any.py", "fit", "--optimizer=Adam"]), mock.patch(
        "lightning.pytorch.Trainer._run_stage"
    ) as run_stage, mock.patch(validator_target, wraps=__verify_train_val_loop_configuration) as verify:
        cli = LightningCLI(BoringModel)

    run_stage.assert_called_once()
    verify.assert_called_once_with(cli.trainer, cli.model)
def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(cleandir):
    """Optimizer and scheduler args registered with a tuple of classes accept a ``class_path`` selection."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
            parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR))

    optimizer_json = json.dumps({"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}})
    scheduler_json = json.dumps({"class_path": "torch.optim.lr_scheduler.StepLR", "init_args": {"step_size": 50}})
    argv = [
        "any.py",
        "fit",
        "--trainer.max_epochs=1",
        f"--optimizer={optimizer_json}",
        f"--lr_scheduler={scheduler_json}",
    ]
    with mock.patch("sys.argv", argv):
        cli = MyLightningCLI(BoringModel)

    trainer = cli.trainer
    assert len(trainer.optimizers) == 1
    assert isinstance(trainer.optimizers[0], torch.optim.Adam)
    assert len(trainer.lr_scheduler_configs) == 1
    scheduler = trainer.lr_scheduler_configs[0].scheduler
    assert isinstance(scheduler, torch.optim.lr_scheduler.StepLR)
    assert scheduler.step_size == 50
@pytest.mark.parametrize("use_generic_base_class", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base_class):
    """Optimizer and scheduler args linked with ``link_to`` reach the model's init as instantiable configs."""

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(
                (torch.optim.Optimizer,) if use_generic_base_class else torch.optim.Adam,
                nested_key="optim1",
                link_to="model.optim1",
            )
            parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
            parser.add_lr_scheduler_args(
                LRSchedulerTypeTuple if use_generic_base_class else torch.optim.lr_scheduler.ExponentialLR,
                link_to="model.scheduler",
            )

    class TestModel(BoringModel):
        def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
            super().__init__()
            self.optim1 = instantiate_class(self.parameters(), optim1)
            self.optim2 = instantiate_class(self.parameters(), optim2)
            self.scheduler = instantiate_class(self.optim1, scheduler)

    cli_args = ["fit", "--trainer.fast_dev_run=1"]
    if use_generic_base_class:
        cli_args += [
            "--optim1",
            "Adam",
            "--optim1.weight_decay",
            "0.001",
            "--optim2=SGD",
            "--optim2.lr=0.01",
            "--lr_scheduler=ExponentialLR",
        ]
    else:
        cli_args += ["--optim2=SGD", "--optim2.lr=0.01"]
    cli_args += ["--lr_scheduler.gamma=0.2"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(TestModel)

    assert isinstance(cli.model.optim1, torch.optim.Adam)
    assert isinstance(cli.model.optim2, torch.optim.SGD)
    assert cli.model.optim2.param_groups[0]["lr"] == 0.01
    assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
    # the linked init args must actually reach the objects built inside the model, not just their types
    assert cli.model.scheduler.gamma == 0.2
    if use_generic_base_class:
        assert cli.model.optim1.param_groups[0]["weight_decay"] == 0.001
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
    """Callable-typed init args let the model build its own optimizers and schedulers from CLI choices."""

    class TestModel(BoringModel):
        def __init__(
            self,
            optim1: OptimizerCallable = torch.optim.Adam,
            optim2: OptimizerCallable = torch.optim.Adagrad,
            scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
        ):
            super().__init__()
            self.optim1 = optim1
            self.optim2 = optim2
            self.scheduler = scheduler

        def configure_optimizers(self):
            first = self.optim1(self.parameters())
            second = self.optim2(self.parameters())
            return (
                {"optimizer": first},
                {"optimizer": second, "lr_scheduler": self.scheduler(second)},
            )

    help_buffer = StringIO()
    with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(help_buffer), pytest.raises(SystemExit):
        LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
    help_text = help_buffer.getvalue()

    for absent in ("--optimizer", "--lr_scheduler"):
        assert absent not in help_text
    for present in ("--model.optim1", "--model.optim2", "--model.scheduler"):
        assert present in help_text

    argv = [
        "any.py",
        "--model.optim1=Adagrad",
        "--model.optim2=SGD",
        "--model.optim2.lr=0.007",
        "--model.scheduler=ExponentialLR",
        "--model.scheduler.gamma=0.3",
    ]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)

    first_config, second_config = cli.model.configure_optimizers()
    assert isinstance(first_config["optimizer"], torch.optim.Adagrad)
    assert isinstance(second_config["optimizer"], torch.optim.SGD)
    assert isinstance(second_config["lr_scheduler"], torch.optim.lr_scheduler.ExponentialLR)
    assert second_config["optimizer"].param_groups[0]["lr"] == 0.007
    assert second_config["lr_scheduler"].gamma == 0.3
@pytest.mark.parametrize("fn", [trainer_fn.value for trainer_fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
    """Each subcommand invokes its ``before_<fn>``, ``<fn>`` and ``after_<fn>`` hooks, in that order."""

    class TestCLI(LightningCLI):
        def __init__(self, *args, **kwargs):
            self.called = []
            super().__init__(*args, **kwargs)

        def _record_call(self, name):
            self.called.append(name)

        def before_fit(self):
            self._record_call("before_fit")

        def fit(self, **_):
            self._record_call("fit")

        def after_fit(self):
            self._record_call("after_fit")

        def before_validate(self):
            self._record_call("before_validate")

        def validate(self, **_):
            self._record_call("validate")

        def after_validate(self):
            self._record_call("after_validate")

        def before_test(self):
            self._record_call("before_test")

        def test(self, **_):
            self._record_call("test")

        def after_test(self):
            self._record_call("after_test")

        def before_predict(self):
            self._record_call("before_predict")

        def predict(self, **_):
            self._record_call("predict")

        def after_predict(self):
            self._record_call("after_predict")

    with mock.patch("sys.argv", ["any.py", fn]):
        cli = TestCLI(BoringModel)

    expected_calls = [f"before_{fn}", fn, f"after_{fn}"]
    assert cli.called == expected_calls
def test_lightning_cli_subcommands():
    """Every argument a CLI subcommand excludes must still exist on the matching ``Trainer`` method."""
    exclusions = LightningCLI.subcommands()
    trainer = Trainer()
    for name, excluded_params in exclusions.items():
        accepted_params = set(inspect.signature(getattr(trainer, name)).parameters)
        # if this fails, it's because the parameter has been removed from the associated `Trainer` function
        # and the `LightningCLI` subcommand exclusion list needs to be updated
        for param in excluded_params:
            assert param in accepted_params
@pytest.mark.skipif(compare_version("jsonargparse", operator.lt, "4.21.3"), reason="vulnerability with failing imports")
def test_lightning_cli_custom_subcommand():
    """An extra Trainer method registered in ``subcommands()`` becomes a CLI subcommand documented by its docstring."""

    class TestTrainer(Trainer):
        # NOTE: this docstring is read at runtime; the help assertions below depend on its exact text.
        def foo(self, model: LightningModule, x: int, y: float = 1.0):
            """Sample extra function.
            Args:
                model: A model
                x: The x
                y: The y
            """

    class TestCLI(LightningCLI):
        @staticmethod
        def subcommands():
            # start from the built-in subcommands and add `foo`, excluding its `model` argument from the parser
            subcommands = LightningCLI.subcommands()
            subcommands["foo"] = {"model"}
            return subcommands

    # top-level help: the new subcommand is listed next to the built-in ones, with its docstring summary
    out = StringIO()
    with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestCLI(BoringModel, trainer_class=TestTrainer)
    out = out.getvalue()
    assert "Sample extra function." in out
    assert "{fit,validate,test,predict,foo}" in out

    # `foo -h`: the excluded `model` argument is absent; `x` and `y` show docstring descriptions, types and defaults
    out = StringIO()
    with mock.patch("sys.argv", ["any.py", "foo", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestCLI(BoringModel, trainer_class=TestTrainer)
    out = out.getvalue()
    assert "A model" not in out
    assert "Sample extra function:" in out
    assert "--x X" in out
    assert "The x (required, type: int)" in out
    assert "--y Y" in out
    assert "The y (type: float, default: 1.0)" in out
def test_lightning_cli_run(cleandir):
    """``run=False`` only instantiates objects, while the ``fit`` subcommand actually trains."""
    scenarios = (
        (["any.py"], {"run": False}, 0),
        (["any.py", "fit"], {"trainer_defaults": {"max_steps": 1, "max_epochs": 1}}, 1),
    )
    for argv, cli_kwargs, expected_step in scenarios:
        with mock.patch("sys.argv", argv):
            cli = LightningCLI(BoringModel, **cli_kwargs)
        assert cli.trainer.global_step == expected_step
        assert isinstance(cli.trainer, Trainer)
        assert isinstance(cli.model, LightningModule)
class TestModel(BoringModel):
    """Model with a required ``foo`` and an optional ``bar`` init argument, selected via ``--model=TestModel``."""

    # Not a test class: without this, pytest tries to collect it because of the `Test` prefix and warns
    # that it cannot, since the class defines `__init__`.
    __test__ = False

    def __init__(self, foo, bar=5):
        super().__init__()
        self.foo = foo
        self.bar = bar
def test_lightning_cli_model_short_arguments():
    """Registered model subclasses can be selected by their bare class name through ``--model``."""
    with ExitStack() as stack:
        stack.enter_context(mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]))
        fit_impl = stack.enter_context(mock.patch("lightning.pytorch.Trainer._fit_impl"))
        stack.enter_context(mock_subclasses(LightningModule, BoringModel, TestModel))
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
    assert isinstance(cli.model, BoringModel)
    # the model is the first positional argument of the mocked ``_fit_impl``
    fit_impl.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY)

    with ExitStack() as stack:
        stack.enter_context(mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]))
        stack.enter_context(mock_subclasses(LightningModule, BoringModel, TestModel))
        cli = LightningCLI(run=False)
    assert isinstance(cli.model, TestModel)
    assert cli.model.foo == 123
    # ``bar`` was not given, so its signature default is used
    assert cli.model.bar == 5
class MyDataModule(BoringDataModule):
    """Boring datamodule with one required (``foo``) and one optional (``bar``) init argument."""

    def __init__(self, foo, bar=5):
        super().__init__()
        self.foo, self.bar = foo, bar
def test_lightning_cli_datamodule_short_arguments():
    """Registered datamodule subclasses can be selected by bare class name through ``--data``."""
    # --- fixed model class, datamodule chosen on the command line ---
    with ExitStack() as stack:
        stack.enter_context(mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]))
        fit_impl = stack.enter_context(mock.patch("lightning.pytorch.Trainer._fit_impl"))
        stack.enter_context(mock_subclasses(LightningDataModule, BoringDataModule))
        cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1})
    assert isinstance(cli.datamodule, BoringDataModule)
    # the datamodule is the fourth positional argument of the mocked ``_fit_impl``
    fit_impl.assert_called_once_with(ANY, ANY, ANY, cli.datamodule, ANY)

    with ExitStack() as stack:
        stack.enter_context(mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]))
        stack.enter_context(mock_subclasses(LightningDataModule, MyDataModule))
        cli = LightningCLI(BoringModel, run=False)
    assert isinstance(cli.datamodule, MyDataModule)
    assert cli.datamodule.foo == 123
    assert cli.datamodule.bar == 5

    # --- both model and datamodule chosen on the command line ---
    with ExitStack() as stack:
        stack.enter_context(
            mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"])
        )
        fit_impl = stack.enter_context(mock.patch("lightning.pytorch.Trainer._fit_impl"))
        stack.enter_context(mock_subclasses(LightningModule, BoringModel))
        stack.enter_context(mock_subclasses(LightningDataModule, BoringDataModule))
        cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
    assert isinstance(cli.model, BoringModel)
    assert isinstance(cli.datamodule, BoringDataModule)
    fit_impl.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY)

    with ExitStack() as stack:
        stack.enter_context(mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]))
        stack.enter_context(mock_subclasses(LightningModule, BoringModel))
        stack.enter_context(mock_subclasses(LightningDataModule, MyDataModule))
        cli = LightningCLI(run=False)
    assert isinstance(cli.model, BoringModel)
    assert isinstance(cli.datamodule, MyDataModule)

    # --- no datamodule argument at all ---
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    # data was not passed, but the group is added automatically because datamodules are registered
    data_group = cli.parser.groups["data"] if "data" in cli.parser.groups else None
    assert data_group is not None
    assert not hasattr(data_group, "group_class")

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, BoringDataModule, run=False)
    # an explicitly passed datamodule class is what the parser group is built from
    assert cli.parser.groups["data"].group_class is BoringDataModule
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_callbacks_append(use_class_path_callbacks):
    """Validate that registries are used when the simplified command line syntax is used.

    Also checks that each ``--trainer.callbacks+=`` appends to the callbacks list instead of replacing it.
    """
    from collections import Counter

    cli_args = [
        "--optimizer",
        "Adam",
        "--optimizer.lr",
        "0.0001",
        "--trainer.callbacks+=LearningRateMonitor",
        "--trainer.callbacks.logging_interval=epoch",
        "--trainer.callbacks.log_momentum=True",
        "--model=BoringModel",
        "--trainer.callbacks+",
        "ModelCheckpoint",
        "--trainer.callbacks.monitor=loss",
        "--lr_scheduler",
        "StepLR",
        "--lr_scheduler.step_size=50",
    ]

    extras = []
    if use_class_path_callbacks:
        callbacks = [
            {"class_path": "lightning.pytorch.callbacks.Callback"},
            {"class_path": "lightning.pytorch.callbacks.Callback", "init_args": {}},
        ]
        cli_args += [f"--trainer.callbacks+={json.dumps(callbacks)}"]
        extras = [Callback, Callback]

    with mock.patch("sys.argv", ["any.py"] + cli_args), mock_subclasses(LightningModule, BoringModel):
        cli = LightningCLI(run=False)

    assert isinstance(cli.model, BoringModel)
    optimizers, lr_scheduler = cli.model.configure_optimizers()
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert optimizers[0].param_groups[0]["lr"] == 0.0001
    assert lr_scheduler[0].step_size == 50

    # Compare as multisets: a plain membership test would still pass if only one of the two
    # class-path ``Callback`` entries had been appended.
    callback_types = Counter(type(c) for c in cli.trainer.callbacks)
    expected = Counter([LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras)
    missing = expected - callback_types
    assert not missing, f"missing callbacks: {dict(missing)}"
def test_optimizers_and_lr_schedulers_reload(cleandir):
    """A config printed with ``--print_config`` can be reloaded with ``--config`` and gives the same optimization."""
    base = ["any.py", "--trainer.max_epochs=1"]
    # named ``cli_args`` rather than ``input`` to avoid shadowing the builtin
    cli_args = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--optimizer",
        "Adam",
        "--optimizer.lr=0.1",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", cli_args + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(BoringModel, run=False)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["optimizer"]["class_path"] == "torch.optim.Adam"
    assert dict_config["optimizer"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"

    # reload config
    yaml_config_file = Path("config.yaml")
    yaml_config_file.write_text(yaml_config)
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        cli = LightningCLI(BoringModel, run=False)

    # the reloaded configuration must produce the same optimizer / scheduler classes
    [optimizer], [scheduler] = cli.model.configure_optimizers()
    assert isinstance(optimizer, torch.optim.Adam)
    assert isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)
def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(cleandir):
    """Optimizer/scheduler args linked to model init parameters survive a print-config / reload round trip."""

    class TestLightningCLI(LightningCLI):
        def __init__(self, *args):
            super().__init__(*args, run=False)

        def add_arguments_to_parser(self, parser):
            # link the optimization configs to model init arguments instead of the automatic override
            parser.add_optimizer_args(nested_key="opt1", link_to="model.opt1_config")
            parser.add_optimizer_args(
                (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
            )
            parser.add_lr_scheduler_args(link_to="model.sch_config")
            parser.add_argument("--something", type=str, nargs="+")

    class TestModel(BoringModel):
        def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
            super().__init__()
            self.opt1_config = opt1_config
            self.opt2_config = opt2_config
            self.sch_config = sch_config
            # the linked configs must be instantiable into the requested classes
            opt1 = instantiate_class(self.parameters(), opt1_config)
            assert isinstance(opt1, torch.optim.Adam)
            opt2 = instantiate_class(self.parameters(), opt2_config)
            assert isinstance(opt2, torch.optim.ASGD)
            sch = instantiate_class(opt1, sch_config)
            assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR)

    base = ["any.py", "--trainer.max_epochs=1"]
    # named ``cli_args`` rather than ``input`` to avoid shadowing the builtin
    cli_args = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--opt1",
        "Adam",
        "--opt2=ASGD",
        "--opt2.lr=0.1",
        "--lr_scheduler.anneal_strategy=linear",
        "--something",
        "a",
        "b",
        "c",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", cli_args + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestLightningCLI(TestModel)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["opt1"]["class_path"] == "torch.optim.Adam"
    assert dict_config["opt2"]["class_path"] == "torch.optim.ASGD"
    assert dict_config["opt2"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear"
    assert dict_config["something"] == ["a", "b", "c"]

    # reload config
    yaml_config_file = Path("config.yaml")
    yaml_config_file.write_text(yaml_config)
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        cli = TestLightningCLI(TestModel)

    assert cli.model.opt1_config["class_path"] == "torch.optim.Adam"
    assert cli.model.opt2_config["class_path"] == "torch.optim.ASGD"
    assert cli.model.opt2_config["init_args"]["lr"] == 0.1
    assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear"
def test_lightning_cli_config_with_subcommand():
    """A config passed before any subcommand can itself select the subcommand and its arguments."""
    config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    argv = ["any.py", f"--config={config}"]
    with mock.patch("sys.argv", argv), mock.patch("lightning.pytorch.Trainer.test", autospec=True) as trainer_test:
        cli = LightningCLI(BoringModel)
    trainer_test.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
def test_lightning_cli_config_before_subcommand():
    """A config holding sections for several subcommands applies only the section of the chosen subcommand."""
    config = {
        "validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"},
        "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"},
    }
    config_arg = f"--config={config}"

    # --- ``test`` ---
    test_target = "lightning.pytorch.Trainer.test"
    with mock.patch("sys.argv", ["any.py", config_arg, "test"]), mock.patch(test_target, autospec=True) as run_test:
        cli = LightningCLI(BoringModel)
    run_test.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
    # the config-saving callback sees the resolved subcommand and its config
    saver = cli.trainer.callbacks[0]
    assert saver.config.trainer.limit_test_batches == 1
    assert saver.parser.subcommand == "test"

    # --- ``validate`` ---
    validate_target = "lightning.pytorch.Trainer.validate"
    with mock.patch("sys.argv", ["any.py", config_arg, "validate"]), mock.patch(
        validate_target, autospec=True
    ) as run_validate:
        cli = LightningCLI(BoringModel)
    run_validate.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1
    saver = cli.trainer.callbacks[0]
    assert saver.config.trainer.limit_val_batches == 1
    assert saver.parser.subcommand == "validate"
def test_lightning_cli_config_before_subcommand_two_configs():
    """Two ``--config`` options given before the subcommand are both considered when picking the section."""
    validate_config = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
    test_config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    config_args = [f"--config={validate_config}", f"--config={test_config}"]

    with mock.patch("sys.argv", ["any.py", *config_args, "test"]), mock.patch(
        "lightning.pytorch.Trainer.test", autospec=True
    ) as run_test:
        cli = LightningCLI(BoringModel)
    run_test.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    with mock.patch("sys.argv", ["any.py", *config_args, "validate"]), mock.patch(
        "lightning.pytorch.Trainer.validate", autospec=True
    ) as run_validate:
        cli = LightningCLI(BoringModel)
    run_validate.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1
def test_lightning_cli_config_after_subcommand():
    """A config given after the subcommand is applied directly to that subcommand's arguments."""
    subcommand_config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}
    argv = ["any.py", "test", f"--config={subcommand_config}"]
    with mock.patch("sys.argv", argv), mock.patch("lightning.pytorch.Trainer.test", autospec=True) as run_test:
        cli = LightningCLI(BoringModel)
    run_test.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
def test_lightning_cli_config_before_and_after_subcommand():
    """Configs before and after the subcommand are merged, with the later one overriding shared keys."""
    before = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    after = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}
    argv = ["any.py", f"--config={before}", "test", f"--config={after}"]
    with mock.patch("sys.argv", argv), mock.patch("lightning.pytorch.Trainer.test", autospec=True) as run_test:
        cli = LightningCLI(BoringModel)
    # ``verbose`` comes from the later config, the trainer options from both
    run_test.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
    assert cli.trainer.fast_dev_run == 1
def test_lightning_cli_parse_kwargs_with_subcommands(cleandir):
    """Per-subcommand ``parser_kwargs`` default config files only affect their own subcommand."""
    fit_config_path = Path("fit.yaml")
    fit_config_path.write_text(str({"trainer": {"limit_train_batches": 2}}))
    validate_config_path = Path("validate.yaml")
    validate_config_path.write_text(str({"trainer": {"limit_val_batches": 3}}))

    parser_kwargs = {
        "fit": {"default_config_files": [str(fit_config_path)]},
        "validate": {"default_config_files": [str(validate_config_path)]},
    }

    # (subcommand, patched trainer method, expected limit_train_batches, expected limit_val_batches);
    # the value not set by that subcommand's config file is expected to stay at 1.0
    cases = (
        ("fit", "lightning.pytorch.Trainer.fit", 2, 1.0),
        ("validate", "lightning.pytorch.Trainer.validate", 1.0, 3),
    )
    for subcommand, target, train_batches, val_batches in cases:
        with mock.patch("sys.argv", ["any.py", subcommand]), mock.patch(target, autospec=True) as run_mock:
            cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
        run_mock.assert_called()
        assert cli.trainer.limit_train_batches == train_batches
        assert cli.trainer.limit_val_batches == val_batches
def test_lightning_cli_subcommands_common_default_config_files(cleandir):
    """A default config file shared by all subcommands may contain per-subcommand sections."""

    class Model(BoringModel):
        def __init__(self, foo: int, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.foo = foo

    config_path = Path("default.yaml")
    config_path.write_text(str({"fit": {"model": {"foo": 123}}}))

    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
        "lightning.pytorch.Trainer.fit", autospec=True
    ) as run_fit:
        cli = LightningCLI(Model, parser_kwargs={"default_config_files": [str(config_path)]})
    run_fit.assert_called()
    # the required ``foo`` was filled in from the ``fit`` section of the default config
    assert cli.model.foo == 123
def test_lightning_cli_reinstantiate_trainer():
    """``instantiate_trainer`` builds a fresh trainer with overrides without touching the parsed config."""
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.max_epochs is None

    class TestCallback(Callback):
        ...

    # make sure a new trainer can be easily created
    trainer = cli.instantiate_trainer(max_epochs=123, callbacks=[TestCallback()])

    # the override is applied to the new trainer ...
    assert trainer.max_epochs == 123
    original_types = {type(cb) for cb in cli.trainer.callbacks}
    assert {type(cb) for cb in trainer.callbacks} == original_types | {TestCallback}
    # ... but the stored config is left unchanged
    assert cli.config_init["trainer"]["max_epochs"] is None
def test_cli_configure_optimizers_warning():
    """A warning is raised only when an optimizer is configured that will override ``configure_optimizers``."""
    expected_message = "configure_optimizers` will be overridden by `LightningCLI"

    # no optimizer given: nothing to override, no warning
    with mock.patch("sys.argv", ["any.py"]), no_warning_call(UserWarning, match=expected_message):
        LightningCLI(BoringModel, run=False)

    # optimizer given: BoringModel's own method gets replaced, so warn
    with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]), pytest.warns(UserWarning, match=expected_message):
        LightningCLI(BoringModel, run=False)
def test_cli_help_message():
    """``--optimizer.help`` prints the same text for a short class name as for the full import path."""

    def optimizer_help(target):
        # run the CLI asking for help on ``target`` and return what it printed before exiting
        captured = StringIO()
        argv = ["any.py", f"--optimizer.help={target}"]
        with mock.patch("sys.argv", argv), redirect_stdout(captured), pytest.raises(SystemExit):
            LightningCLI(BoringModel, run=False)
        return captured.getvalue()

    classpath_help = optimizer_help("torch.optim.Adam")
    shorthand_help = optimizer_help("Adam")

    assert shorthand_help == classpath_help
    # guard against both being trivially equal because they are empty
    assert "Implements Adam" in shorthand_help
def test_cli_reducelronplateau():
    """``ReduceLROnPlateau`` is returned in the dict format together with its configured ``monitor``."""
    argv = ["any.py", "--optimizer=Adam", "--lr_scheduler=ReduceLROnPlateau", "--lr_scheduler.monitor=foo"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(BoringModel, run=False)
    scheduler = cli.model.configure_optimizers()["lr_scheduler"]["scheduler"]
    assert isinstance(scheduler, ReduceLROnPlateau)
    assert scheduler.monitor == "foo"
def test_cli_configureoptimizers_can_be_overridden():
    """A ``LightningCLI`` subclass can replace the ``configure_optimizers`` installed on the model."""

    class MyCLI(LightningCLI):
        def __init__(self):
            super().__init__(BoringModel, run=False)

        @staticmethod
        def configure_optimizers(lightning_module, optimizer, lr_scheduler=None):
            # receives the model being configured as its first argument
            assert isinstance(lightning_module, BoringModel)
            assert lr_scheduler is None
            return 123

    with mock.patch("sys.argv", ["any.py", "--optimizer=Adam"]):
        cli = MyCLI()
    assert cli.model.configure_optimizers() == 123

    # with no optimization config, we don't override: BoringModel's own method stays in place
    no_optimizer_argvs = (
        ["any.py"],
        ["any.py", "--lr_scheduler=StepLR", "--lr_scheduler.step_size=50"],
    )
    for argv in no_optimizer_argvs:
        with mock.patch("sys.argv", argv):
            cli = MyCLI()
        [optimizer], [scheduler] = cli.model.configure_optimizers()
        assert isinstance(optimizer, SGD)
        assert isinstance(scheduler, StepLR)
def test_cli_parameter_with_lazy_instance_default():
    """A ``lazy_instance`` default is re-created by the CLI rather than shared with direct instantiation."""

    class TestModel(BoringModel):
        def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05)):
            super().__init__()
            self.activation = activation

    direct = TestModel()
    assert isinstance(direct.activation, torch.nn.LeakyReLU)

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False)
    activation = cli.model.activation
    assert isinstance(activation, torch.nn.LeakyReLU)
    assert activation.negative_slope == 0.05
    # the CLI got its own instance, not the one used by the direct construction
    assert activation is not direct.activation
def test_ddpstrategy_instantiation_and_find_unused_parameters(mps_count_0):
    """A lazily instantiated strategy default keeps its init args and is combined with command line overrides."""
    strategy_default = lazy_instance(DDPStrategy, find_unused_parameters=True)
    argv = ["any.py", "--trainer.strategy.process_group_backend=group"]
    with mock.patch("sys.argv", argv):
        cli = LightningCLI(BoringModel, run=False, trainer_defaults={"strategy": strategy_default})

    assert cli.config.trainer.strategy.init_args.find_unused_parameters is True
    strategy = cli.config_init.trainer.strategy
    assert isinstance(strategy, DDPStrategy)
    assert strategy.process_group_backend == "group"
    # the default object itself must not be reused as the instantiated strategy
    assert strategy_default is not strategy
def test_cli_logger_shorthand():
    """``trainer.logger`` accepts ``False`` and a bare logger class name."""
    no_logger = {"logger": False}

    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False, trainer_defaults=no_logger)
    assert cli.trainer.logger is None

    tensorboard_argv = ["any.py", "--trainer.logger=TensorBoardLogger", "--trainer.logger.save_dir=foo"]
    with mock.patch("sys.argv", tensorboard_argv):
        cli = LightningCLI(TestModel, run=False, trainer_defaults=no_logger)
    assert isinstance(cli.trainer.logger, TensorBoardLogger)

    with mock.patch("sys.argv", ["any.py", "--trainer.logger=False"]):
        cli = LightningCLI(TestModel, run=False)
    assert cli.trainer.logger is None
def _test_logger_init_args(logger_name, init, unresolved=None):
    """Check where logger arguments end up in the output of ``--print_config``.

    Args:
        logger_name: Short class name passed to ``--trainer.logger``.
        init: Arguments passed as ``--trainer.logger.<key>``; each must appear unchanged in ``init_args``.
        unresolved: Arguments passed as ``--trainer.logger.dict_kwargs.<key>``; when non-empty they must
            appear unchanged in ``dict_kwargs``. ``None`` (the default) means no such arguments.

    """
    # ``None`` instead of a ``{}`` default avoids a mutable default shared across calls
    if unresolved is None:
        unresolved = {}

    cli_args = [f"--trainer.logger={logger_name}"]
    cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
    cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
    cli_args.append("--print_config")

    out = StringIO()
    with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(TestModel, run=False)

    data = yaml.safe_load(out.getvalue())["trainer"]["logger"]
    assert {k: data["init_args"][k] for k in init} == init
    if unresolved:
        assert data["dict_kwargs"] == unresolved
@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
def test_comet_logger_init_args():
    """``CometLogger`` arguments are resolved into ``init_args``."""
    _test_logger_init_args(
        "CometLogger",
        init={
            "save_dir": "comet",  # Resolve from CometLogger.__init__
            "workspace": "comet",  # Resolve from Comet{,Existing,Offline}Experiment.__init__
        },
    )
@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune is required")
def test_neptune_logger_init_args():
    """``NeptuneLogger`` arguments land in ``init_args`` or, when not resolvable, in ``dict_kwargs``."""
    _test_logger_init_args(
        "NeptuneLogger",
        init={
            "name": "neptune",  # Resolve from NeptuneLogger.__init__
        },
        unresolved={
            "description": "neptune",  # Unsupported resolving from neptune.internal.init.run.init_run
        },
    )
def test_tensorboard_logger_init_args():
    """``TensorBoardLogger`` arguments land in ``init_args`` or, when not resolvable, in ``dict_kwargs``."""
    _test_logger_init_args(
        "TensorBoardLogger",
        init={
            "save_dir": "tb",  # Resolve from TensorBoardLogger.__init__
        },
        unresolved={
            "comment": "tb",  # Unsupported resolving from local imports
        },
    )
@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
def test_wandb_logger_init_args():
    """``WandbLogger`` arguments are resolved into ``init_args``."""
    _test_logger_init_args(
        "WandbLogger",
        init={
            "save_dir": "wandb",  # Resolve from WandbLogger.__init__
            "notes": "wandb",  # Resolve from wandb.sdk.wandb_init.init
        },
    )
def test_cli_auto_seeding():
    """``--seed_everything`` combined with ``seed_everything_default`` resolves to the expected seed."""

    def build(*argv, **cli_kwargs):
        # instantiate (without running) the CLI for the given command line arguments
        with mock.patch("sys.argv", ["any.py", *argv]):
            return LightningCLI(TestModel, run=False, **cli_kwargs)

    # no seed on the command line: the default decides
    cli = build(seed_everything_default=False)
    assert cli.seed_everything_default is False
    assert cli.config["seed_everything"] is False

    cli = build(seed_everything_default=True)
    assert cli.seed_everything_default is True
    assert isinstance(cli.config["seed_everything"], int)

    # an explicit integer seed wins over any default
    cli = build("--seed_everything", "3", seed_everything_default=False)
    assert cli.seed_everything_default is False
    assert cli.config["seed_everything"] == 3

    cli = build("--seed_everything", "3", seed_everything_default=True)
    assert cli.seed_everything_default is True
    assert cli.config["seed_everything"] == 3

    cli = build("--seed_everything", "3", seed_everything_default=10)
    assert cli.seed_everything_default == 10
    assert cli.config["seed_everything"] == 3

    # ``false`` disables seeding whatever the default is
    cli = build("--seed_everything", "false", seed_everything_default=10)
    assert cli.seed_everything_default == 10
    assert cli.config["seed_everything"] is False

    cli = build("--seed_everything", "false", seed_everything_default=True)
    assert cli.seed_everything_default is True
    assert cli.config["seed_everything"] is False

    # ``true`` requests a seed even when the default disables it
    cli = build("--seed_everything", "true", seed_everything_default=False)
    assert cli.seed_everything_default is False
    assert isinstance(cli.config["seed_everything"], int)

    # with nothing specified, a previously set global seed is picked up
    seed_everything(123)
    cli = build()
    assert cli.seed_everything_default is True
    assert cli.config["seed_everything"] == 123  # the original seed is kept
def test_cli_trainer_no_callbacks():
    """A trainer class without a ``callbacks`` parameter triggers a warning and ignores callback defaults."""

    class MyTrainer(Trainer):
        def __init__(self):
            super().__init__()

    class MyCallback(Callback):
        ...

    expected_warning = "MyTrainer` class does not expose the `callbacks"
    with mock.patch("sys.argv", ["any.py"]), pytest.warns(UserWarning, match=expected_warning):
        cli = LightningCLI(
            BoringModel,
            run=False,
            trainer_class=MyTrainer,
            trainer_defaults={"callbacks": MyCallback()},
        )
    assert all(not isinstance(callback, MyCallback) for callback in cli.trainer.callbacks)
def test_unresolvable_import_paths():
    """A class used as a callable default is printed by its public import path."""

    class TestModel(BoringModel):
        def __init__(self, a_func: Callable = torch.nn.Softmax):
            super().__init__()
            self.a_func = a_func

    printed = StringIO()
    with mock.patch("sys.argv", ["any.py", "--print_config"]), redirect_stdout(printed), pytest.raises(SystemExit):
        LightningCLI(TestModel, run=False)
    assert "a_func: torch.nn.Softmax" in printed.getvalue()
def test_pytorch_profiler_init_args():
    """``PyTorchProfiler`` arguments are resolved into ``init_args``; unresolvable ones stay in ``dict_kwargs``."""
    from lightning.pytorch.profilers import Profiler, PyTorchProfiler

    init = {
        "dirpath": "profiler",  # Resolve from PyTorchProfiler.__init__
        "row_limit": 10,  # Resolve from PyTorchProfiler.__init__
        "group_by_input_shapes": True,  # Resolve from PyTorchProfiler.__init__
    }
    unresolved = {
        "profile_memory": True,  # Not possible to resolve parameters from dynamically chosen Type[_PROFILER]
        "record_shapes": True,  # Resolve from PyTorchProfiler.__init__, gets moved to init_args
    }

    cli_args = ["--trainer.profiler=PyTorchProfiler"]
    cli_args.extend(f"--trainer.profiler.{key}={value}" for key, value in init.items())
    cli_args.extend(f"--trainer.profiler.dict_kwargs.{key}={value}" for key, value in unresolved.items())

    with mock.patch("sys.argv", ["any.py"] + cli_args), mock_subclasses(Profiler, PyTorchProfiler):
        cli = LightningCLI(TestModel, run=False)
    assert isinstance(cli.config_init.trainer.profiler, PyTorchProfiler)

    # ``record_shapes`` was given as a dict kwarg but is expected to be moved into init_args
    expected_init = {**init, "record_shapes": unresolved["record_shapes"]}
    expected_unresolved = {key: value for key, value in unresolved.items() if key != "record_shapes"}
    profiler_config = cli.config.trainer.profiler
    assert {key: profiler_config.init_args[key] for key in expected_init} == expected_init
    assert profiler_config.dict_kwargs == expected_unresolved
@pytest.mark.parametrize(
    "args",
    [
        ["--trainer.logger=False", "--model.foo=456"],
        {"trainer": {"logger": False}, "model": {"foo": 456}},
        Namespace(trainer=Namespace(logger=False), model=Namespace(foo=456)),
    ],
)
def test_lightning_cli_with_args_given(args):
    """Arguments can be given programmatically as an argv list, a nested dict or a ``Namespace``."""
    # an empty command line ensures the values can only come from ``args``
    with mock.patch("sys.argv", [""]):
        cli = LightningCLI(TestModel, run=False, args=args)
    model = cli.model
    assert isinstance(model, TestModel)
    assert cli.config.trainer.logger is False
    assert model.foo == 456
def test_lightning_cli_args_and_sys_argv_warning():
    """Passing ``args`` while ``sys.argv`` also carries arguments emits a warning."""
    expected = "LightningCLI's args parameter "
    with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.warns(Warning, match=expected):
        LightningCLI(TestModel, run=False, args=["--model.foo=789"])