# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import pickle
import sys
from argparse import Namespace
from contextlib import redirect_stdout
from io import StringIO
from typing import List, Optional, Union
from unittest import mock

import pytest
import torch
import yaml
from packaging import version

from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
    CALLBACK_REGISTRY,
    instantiate_class,
    LightningArgumentParser,
    LightningCLI,
    LR_SCHEDULER_REGISTRY,
    OPTIMIZER_REGISTRY,
    SaveConfigCallback,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf

torchvision_version = version.parse("0")
if _TORCHVISION_AVAILABLE:
    torchvision_version = version.parse(__import__("torchvision").__version__)


@mock.patch("argparse.ArgumentParser.parse_args")
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer."""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    args = parser.parse_args([])

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5


@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []])
def test_add_argparse_args_redefined(cli_args):
    """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)

    args = parser.parse_args(cli_args)

    # make sure we can pickle args
    pickle.dumps(args)

    # Check few deprecated args are not in namespace:
    for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"):
        assert depr_name not in args

    trainer = Trainer.from_argparse_args(args=args)
    pickle.dumps(trainer)
    assert isinstance(trainer, Trainer)


@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
    """Asserts that an error is raised when non-default cli arguments are passed."""

    class _UnkArgError(Exception):
        pass

    def _raise():
        raise _UnkArgError

    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)

    monkeypatch.setattr(parser, "exit", lambda *args: _raise(), raising=True)

    with pytest.raises(_UnkArgError):
        parser.parse_args(cli_args)


@pytest.mark.parametrize(
    ["cli_args", "expected"],
    [
        ("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")),
        (
            "--auto_lr_find any_string --auto_scale_batch_size ON",
            dict(auto_lr_find="any_string", auto_scale_batch_size=True),
        ),
        ("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)),
        ("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)),
        ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)),
        ("--tpu_cores=8", dict(tpu_cores=8)),
        ("--tpu_cores=1,", dict(tpu_cores="1,")),
        ("--limit_train_batches=100", dict(limit_train_batches=100)),
        ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)),
        ("--weights_summary=null", dict(weights_summary=None)),
        (
            "",
            dict(
                # These parameters are marked as Optional[...] in Trainer.__init__,
                # with None as default. They should not be changed by the argparse
                # interface.
                min_steps=None,
                max_steps=None,
                log_gpu_memory=None,
                distributed_backend=None,
                weights_save_path=None,
                resume_from_checkpoint=None,
                profiler=None,
            ),
        ),
    ],
)
def test_parse_args_parsing(cli_args, expected):
    """Test parsing simple types and None optionals not modified."""
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    if "tpu_cores" not in expected or _TPU_AVAILABLE:
        assert Trainer.from_argparse_args(args)


@pytest.mark.parametrize(
    ["cli_args", "expected", "instantiate"],
    [
        (["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False),
        (["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False),
        (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True),
    ],
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
    """Test parsing complex types."""
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    for k, v in expected.items():
        assert getattr(args, k) == v
    if instantiate:
        assert Trainer.from_argparse_args(args)


@pytest.mark.parametrize(["cli_args", "expected_gpu"], [("--gpus 1", [0]), ("--gpus 0,", [0]), ("--gpus 0,1", [0, 1])])
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
    """Test parsing of gpus and instantiation of Trainer."""
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
    cli_args = cli_args.split(" ") if cli_args else []
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        args = parser.parse_args()

    trainer = Trainer.from_argparse_args(args)
    assert trainer.data_parallel_device_ids == expected_gpu


@pytest.mark.skipif(
    sys.version_info < (3, 7),
    reason="signature inspection while mocking is not working in Python < 3.7 despite autospec",
)
@pytest.mark.parametrize(
    ["cli_args", "extra_args"],
    [
        ({}, {}),
        (dict(logger=False), {}),
        (dict(logger=False), dict(logger=True)),
        (dict(logger=False), dict(checkpoint_callback=True)),
    ],
)
def test_init_from_argparse_args(cli_args, extra_args):
    unknown_args = dict(unknown_arg=0)

    # unknown args in the argparser/namespace should be ignored
    with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init:
        trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
    expected = dict(cli_args)
    expected.update(extra_args)  # extra args should override any cli arg
    init.assert_called_with(trainer, **expected)

    # passing in unknown manual args should throw an error
    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
        Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)


class Model(LightningModule):
    def __init__(self, model_param: int):
        super().__init__()
        self.model_param = model_param


def _model_builder(model_param: int) -> Model:
    return Model(model_param)


def _trainer_builder(
    limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None
) -> Trainer:
    return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks)


@pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (_trainer_builder, _model_builder)])
def test_lightning_cli(trainer_class, model_class, monkeypatch):
    """Test that LightningCLI correctly instantiates model, trainer and calls fit."""
    expected_model = dict(model_param=7)
    expected_trainer = dict(limit_train_batches=100)

    def fit(trainer, model):
        for k, v in expected_model.items():
            assert getattr(model, k) == v
        for k, v in expected_trainer.items():
            assert getattr(trainer, k) == v
        save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)]
        assert len(save_callback) == 1
        save_callback[0].on_train_start(trainer, model)

    def on_train_start(callback, trainer, _):
        config_dump = callback.parser.dump(callback.config, skip_none=False)
        for k, v in expected_model.items():
            assert f" {k}: {v}" in config_dump
        for k, v in expected_trainer.items():
            assert f" {k}: {v}" in config_dump
        trainer.ran_asserts = True

    monkeypatch.setattr(Trainer, "fit", fit)
    monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start)

    with mock.patch("sys.argv", ["any.py", "fit", "--model.model_param=7", "--trainer.limit_train_batches=100"]):
        cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)

    assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts


def test_lightning_cli_args_callbacks(tmpdir):

    callbacks = [
        dict(
            class_path="pytorch_lightning.callbacks.LearningRateMonitor",
            init_args=dict(logging_interval="epoch", log_momentum=True),
        ),
        dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")),
    ]

    class TestModel(BoringModel):
        def on_fit_start(self):
            callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)]
            assert len(callback) == 1
            assert callback[0].logging_interval == "epoch"
            assert callback[0].log_momentum is True
            callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
            assert len(callback) == 1
            assert callback[0].monitor == "NAME"
            self.trainer.ran_asserts = True

    with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.callbacks={json.dumps(callbacks)}"]):
        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    assert cli.trainer.ran_asserts


@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_configurable_callbacks(tmpdir, run):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")

        def fit(self, **_):
            pass

    cli_args = ["fit"] if run else []
    cli_args += [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModel, run=run)

    callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)]
    assert len(callback) == 1
    assert callback[0].logging_interval == "epoch"


def test_lightning_cli_args_cluster_environments(tmpdir):
    plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")]

    class TestModel(BoringModel):
        def on_fit_start(self):
            # Ensure SLURMEnvironment is set, instead of default LightningEnvironment
            assert isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment)
            self.trainer.ran_asserts = True

    with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.plugins={json.dumps(plugins)}"]):
        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))

    assert cli.trainer.ran_asserts


def test_lightning_cli_args(tmpdir):

    cli_args = [
        "fit",
        f"--data.data_dir={tmpdir}",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        "--trainer.weights_summary=null",
        "--seed_everything=1234",
    ]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]})

    config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(config_path)
    with open(config_path) as f:
        loaded_config = yaml.safe_load(f.read())

    loaded_config = loaded_config["fit"]
    cli_config = cli.config["fit"]

    assert cli_config["seed_everything"] == 1234
    assert "model" not in loaded_config and "model" not in cli_config  # no arguments to include
    assert loaded_config["data"] == cli_config["data"]
    assert loaded_config["trainer"] == cli_config["trainer"]


def test_lightning_cli_save_config_cases(tmpdir):

    config_path = tmpdir / "config.yaml"
    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False", "--trainer.fast_dev_run=1"]

    # With fast_dev_run!=False config should not be saved
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        LightningCLI(BoringModel)
    assert not os.path.isfile(config_path)

    # With fast_dev_run==False config should be saved
    cli_args[-1] = "--trainer.max_epochs=1"
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        LightningCLI(BoringModel)
    assert os.path.isfile(config_path)

    # If run again on same directory exception should be raised since config file already exists
    with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError):
        LightningCLI(BoringModel)


def test_lightning_cli_config_and_subclass_mode(tmpdir):

    input_config = {
        "fit": {
            "model": {"class_path": "tests.helpers.BoringModel"},
            "data": {"class_path": "tests.helpers.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}},
            "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "weights_summary": None},
        }
    }
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(yaml.dump(input_config))

    with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]):
        cli = LightningCLI(
            BoringModel,
            BoringDataModule,
            subclass_mode_model=True,
            subclass_mode_data=True,
            trainer_defaults={"callbacks": LearningRateMonitor()},
        )

    config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
    assert os.path.isfile(config_path)
    with open(config_path) as f:
        loaded_config = yaml.safe_load(f.read())

    loaded_config = loaded_config["fit"]
    cli_config = cli.config["fit"]
    assert loaded_config["model"] == cli_config["model"]
    assert loaded_config["data"] == cli_config["data"]
    assert loaded_config["trainer"] == cli_config["trainer"]


def any_model_any_data_cli():
    LightningCLI(LightningModule, LightningDataModule, subclass_mode_model=True, subclass_mode_data=True)


def test_lightning_cli_help():

    cli_args = ["any.py", "fit", "--help"]
    out = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        any_model_any_data_cli()
    out = out.getvalue()

    assert "--print_config" in out
    assert "--config" in out
    assert "--seed_everything" in out
    assert "--model.help" in out
    assert "--data.help" in out

    skip_params = {"self"}
    for param in inspect.signature(Trainer.__init__).parameters.keys():
        if param not in skip_params:
            assert f"--trainer.{param}" in out

    cli_args = ["any.py", "fit", "--data.help=tests.helpers.BoringDataModule"]
    out = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        any_model_any_data_cli()

    assert "--data.init_args.data_dir" in out.getvalue()


def test_lightning_cli_print_config():
    cli_args = [
        "any.py",
        "predict",
        "--seed_everything=1234",
        "--model=tests.helpers.BoringModel",
        "--data=tests.helpers.BoringDataModule",
        "--print_config",
    ]
    out = StringIO()
    with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        any_model_any_data_cli()

    outval = yaml.safe_load(out.getvalue())
    assert outval["seed_everything"] == 1234
    assert outval["model"]["class_path"] == "tests.helpers.BoringModel"
    assert outval["data"]["class_path"] == "tests.helpers.BoringDataModule"
    assert outval["ckpt_path"] is None


def test_lightning_cli_submodules(tmpdir):
    class MainModule(BoringModel):
        def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2

    config = """model:
        main_param: 2
        submodule1:
            class_path: tests.helpers.BoringModel
        submodule2:
            class_path: tests.helpers.BoringModel
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(MainModule, run=False)

    assert cli.config["model"]["main_param"] == 2
    assert isinstance(cli.model.submodule1, BoringModel)
    assert isinstance(cli.model.submodule2, BoringModel)


@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required")
def test_lightning_cli_torch_modules(tmpdir):
    class TestModule(BoringModel):
        def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
            super().__init__()
            self.activation = activation
            self.transform = transform

    config = """model:
        activation:
            class_path: torch.nn.LeakyReLU
            init_args:
                negative_slope: 0.2
        transform:
            - class_path: torchvision.transforms.Resize
              init_args:
                size: 64
            - class_path: torchvision.transforms.CenterCrop
              init_args:
                size: 64
    """
    config_path = tmpdir / "config.yaml"
    with open(config_path, "w") as f:
        f.write(config)

    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(TestModule, run=False)

    assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
    assert cli.model.activation.negative_slope == 0.2
    assert len(cli.model.transform) == 2
    assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
class BoringModelRequiredClasses(BoringModel):
    def __init__(self, num_classes: int, batch_size: int = 8):
        super().__init__()
        self.num_classes = num_classes
        self.batch_size = batch_size


class BoringDataModuleBatchSizeAndClasses(BoringDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        self.num_classes = 5  # only available after instantiation


def test_lightning_cli_link_arguments(tmpdir):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.batch_size")
            parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

    cli_args = [f"--trainer.default_root_dir={tmpdir}", "--data.batch_size=12"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)

    assert cli.model.batch_size == 12
    assert cli.model.num_classes == 5

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.init_args.batch_size")
            parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

    cli_args[-1] = "--model=tests.utilities.test_cli.BoringModelRequiredClasses"

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(
            BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
        )

    assert cli.model.batch_size == 8
    assert cli.model.num_classes == 5


class EarlyExitTestModel(BoringModel):
    def on_fit_start(self):
        raise Exception("Error on fit start")


@pytest.mark.parametrize("logger", (False, True))
@pytest.mark.parametrize(
    "trainer_kwargs",
    (
        dict(accelerator="ddp_cpu"),
        dict(accelerator="ddp_cpu", plugins="ddp_find_unused_parameters_false"),
        pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
    ),
)
def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs):
    with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(Exception, match=r"Error on fit start"):
        LightningCLI(
            EarlyExitTestModel,
            trainer_defaults={
                "default_root_dir": str(tmpdir),
                "logger": logger,
                "max_steps": 1,
                "max_epochs": 1,
                **trainer_kwargs,
            },
        )
    if logger:
        config_dir = tmpdir / "lightning_logs"
        # no more version dirs should get created
        assert os.listdir(config_dir) == ["version_0"]
        config_path = config_dir / "version_0" / "config.yaml"
    else:
        config_path = tmpdir / "config.yaml"
    assert os.path.isfile(config_path)


def test_cli_config_overwrite(tmpdir):
    trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1}

    argv = ["any.py", "fit"]
    with mock.patch("sys.argv", argv):
        LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
    with mock.patch("sys.argv", argv), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"):
        LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
    with mock.patch("sys.argv", argv):
        LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)


@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_optimizer(tmpdir, run):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)

    match = (
        "BoringModel.configure_optimizers` will be overridden by "
        "`MyLightningCLI.add_configure_optimizers_method_to_model`"
    )
    argv = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1"] if run else []
    with mock.patch("sys.argv", ["any.py"] + argv), pytest.warns(UserWarning, match=match):
        cli = MyLightningCLI(BoringModel, run=run)

    assert cli.model.configure_optimizers is not BoringModel.configure_optimizers

    if not run:
        optimizer = cli.model.configure_optimizers()
        assert isinstance(optimizer, torch.optim.Adam)
    else:
        assert len(cli.trainer.optimizers) == 1
        assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
        assert len(cli.trainer.lr_schedulers) == 0


def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", "--lr_scheduler.gamma=0.8"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModel)

    assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
    assert len(cli.trainer.optimizers) == 1
    assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
    assert len(cli.trainer.lr_schedulers) == 1
    assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.ExponentialLR)
    assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8


def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
            parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR))

    optimizer_arg = dict(class_path="torch.optim.Adam", init_args=dict(lr=0.01))
    lr_scheduler_arg = dict(class_path="torch.optim.lr_scheduler.StepLR", init_args=dict(step_size=50))
    cli_args = [
        "fit",
        f"--trainer.default_root_dir={tmpdir}",
        "--trainer.max_epochs=1",
        f"--optimizer={json.dumps(optimizer_arg)}",
        f"--lr_scheduler={json.dumps(lr_scheduler_arg)}",
    ]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(BoringModel)

    assert len(cli.trainer.optimizers) == 1
    assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
    assert len(cli.trainer.lr_schedulers) == 1
    assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR)
    assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50


@pytest.mark.parametrize("use_registries", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir):
    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(
                OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam,
                nested_key="optim1",
                link_to="model.optim1",
            )
            parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
            parser.add_lr_scheduler_args(
                LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR,
                link_to="model.scheduler",
            )

    class TestModel(BoringModel):
        def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
            super().__init__()
            self.optim1 = instantiate_class(self.parameters(), optim1)
            self.optim2 = instantiate_class(self.parameters(), optim2)
            self.scheduler = instantiate_class(self.optim1, scheduler)

    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", "--lr_scheduler.gamma=0.2"]
    if use_registries:
        cli_args += [
            "--optim1",
            "Adam",
            "--optim1.weight_decay",
            "0.001",
            "--optim2=SGD",
            "--optim2.lr=0.01",
            "--lr_scheduler=ExponentialLR",
        ]
    else:
        cli_args += ["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = MyLightningCLI(TestModel)

    assert isinstance(cli.model.optim1, torch.optim.Adam)
    assert isinstance(cli.model.optim2, torch.optim.SGD)
    assert cli.model.optim2.param_groups[0]["lr"] == 0.01
    assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
    class TestCLI(LightningCLI):
        def __init__(self, *args, **kwargs):
            self.called = []
            super().__init__(*args, **kwargs)

        def before_fit(self):
            self.called.append("before_fit")

        def fit(self, **_):
            self.called.append("fit")

        def after_fit(self):
            self.called.append("after_fit")

        def before_validate(self):
            self.called.append("before_validate")

        def validate(self, **_):
            self.called.append("validate")

        def after_validate(self):
            self.called.append("after_validate")

        def before_test(self):
            self.called.append("before_test")

        def test(self, **_):
            self.called.append("test")

        def after_test(self):
            self.called.append("after_test")

        def before_predict(self):
            self.called.append("before_predict")

        def predict(self, **_):
            self.called.append("predict")

        def after_predict(self):
            self.called.append("after_predict")

        def before_tune(self):
            self.called.append("before_tune")

        def tune(self, **_):
            self.called.append("tune")

        def after_tune(self):
            self.called.append("after_tune")

    with mock.patch("sys.argv", ["any.py", fn]):
        cli = TestCLI(BoringModel)
    assert cli.called == [f"before_{fn}", fn, f"after_{fn}"]


def test_lightning_cli_subcommands():
    subcommands = LightningCLI.subcommands()
    trainer = Trainer()
    for subcommand, exclude in subcommands.items():
        fn = getattr(trainer, subcommand)
        parameters = list(inspect.signature(fn).parameters)
        for e in exclude:
            # if this fails, it's because the parameter has been removed from the associated `Trainer` function
            # and the `LightningCLI` subcommand exclusion list needs to be updated
            assert e in parameters


def test_lightning_cli_custom_subcommand():
    class TestTrainer(Trainer):
        def foo(self, model: LightningModule, x: int, y: float = 1.0):
            """Sample extra function.

            Args:
                model: A model
                x: The x
                y: The y
            """

    class TestCLI(LightningCLI):
        @staticmethod
        def subcommands():
            subcommands = LightningCLI.subcommands()
            subcommands["foo"] = {"model"}
            return subcommands

    out = StringIO()
    with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestCLI(BoringModel, trainer_class=TestTrainer)
    out = out.getvalue()
    assert "Sample extra function." in out
    assert "{fit,validate,test,predict,tune,foo}" in out

    out = StringIO()
    with mock.patch("sys.argv", ["any.py", "foo", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestCLI(BoringModel, trainer_class=TestTrainer)
    out = out.getvalue()
    assert "A model" not in out
    assert "Sample extra function:" in out
    assert "--x X" in out
    assert "The x (required, type: int)" in out
    assert "--y Y" in out
    assert "The y (type: float, default: 1.0)" in out


def test_lightning_cli_run():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.global_step == 0
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)

    with mock.patch("sys.argv", ["any.py", "fit"]):
        cli = LightningCLI(BoringModel, trainer_defaults={"max_steps": 1, "max_epochs": 1})
    assert cli.trainer.global_step == 1
    assert isinstance(cli.trainer, Trainer)
    assert isinstance(cli.model, LightningModule)


@OPTIMIZER_REGISTRY
class CustomAdam(torch.optim.Adam):
    pass


@LR_SCHEDULER_REGISTRY
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
    pass


@CALLBACK_REGISTRY
class CustomCallback(Callback):
    pass


def test_registries(tmpdir):
    assert "SGD" in OPTIMIZER_REGISTRY.names
    assert "RMSprop" in OPTIMIZER_REGISTRY.names
    assert "CustomAdam" in OPTIMIZER_REGISTRY.names

    assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
    assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
    assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names

    assert "EarlyStopping" in CALLBACK_REGISTRY.names
    assert "CustomCallback" in CALLBACK_REGISTRY.names

    with pytest.raises(MisconfigurationException, match="is already present in the registry"):
        OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)

    # test `_Registry.__call__` returns the class
    assert isinstance(CustomCallback(), CustomCallback)


@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
    """This test validates that the registries are used when the simplified command line syntax is used."""
    cli_args = [
        "--optimizer",
        "Adam",
        "--optimizer.lr",
        "0.0001",
        "--trainer.callbacks=LearningRateMonitor",
        "--trainer.callbacks.logging_interval=epoch",
        "--trainer.callbacks.log_momentum=True",
        "--trainer.callbacks=ModelCheckpoint",
        "--trainer.callbacks.monitor=loss",
        "--lr_scheduler",
        "StepLR",
        "--lr_scheduler.step_size=50",
    ]

    extras = []
    if use_class_path_callbacks:
        callbacks = [
            {"class_path": "pytorch_lightning.callbacks.Callback"},
            {"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
        ]
        cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
        extras = [Callback, Callback]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(BoringModel, run=False)

    optimizers, lr_scheduler = cli.model.configure_optimizers()
    assert isinstance(optimizers[0], torch.optim.Adam)
    assert optimizers[0].param_groups[0]["lr"] == 0.0001
    assert lr_scheduler[0].step_size == 50

    callback_types = [type(c) for c in cli.trainer.callbacks]
    expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
    assert all(t in callback_types for t in expected)


def test_argv_transformation_noop():
    base = ["any.py", "--trainer.max_epochs=1"]
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
    assert argv == base


def test_argv_transformation_single_callback():
    base = ["any.py", "--trainer.max_epochs=1"]
    input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
    callbacks = [
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_loss"},
        }
    ]
    expected = base + ["--trainer.callbacks", str(callbacks)]
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
    assert argv == expected


def test_argv_transformation_multiple_callbacks():
    base = ["any.py", "--trainer.max_epochs=1"]
    input = base + [
        "--trainer.callbacks=ModelCheckpoint",
        "--trainer.callbacks.monitor=val_loss",
        "--trainer.callbacks=ModelCheckpoint",
        "--trainer.callbacks.monitor=val_acc",
    ]
    callbacks = [
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_loss"},
        },
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_acc"},
        },
    ]
    expected = base + ["--trainer.callbacks", str(callbacks)]
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
    assert argv == expected


def test_argv_transformation_multiple_callbacks_with_config():
    base = ["any.py", "--trainer.max_epochs=1"]
    nested_key = "trainer.callbacks"
    input = base + [
        f"--{nested_key}=ModelCheckpoint",
        f"--{nested_key}.monitor=val_loss",
        f"--{nested_key}=ModelCheckpoint",
        f"--{nested_key}.monitor=val_acc",
        f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
    ]
    callbacks = [
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_loss"},
        },
        {
            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
            "init_args": {"monitor": "val_acc"},
        },
        {"class_path": "pytorch_lightning.callbacks.Callback"},
    ]
    expected = base + ["--trainer.callbacks", str(callbacks)]
    nested_key = "trainer.callbacks"
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
    assert argv == expected


@pytest.mark.parametrize(
    ["args", "expected", "nested_key", "registry"],
    [
        (
            ["--optimizer", "Adadelta"],
            {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}},
            "optimizer",
            OPTIMIZER_REGISTRY,
        ),
        (
            ["--optimizer", "Adadelta", "--optimizer.lr", "10"],
            {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}},
            "optimizer",
            OPTIMIZER_REGISTRY,
        ),
        (
            ["--lr_scheduler", "OneCycleLR"],
            {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}},
            "lr_scheduler",
            LR_SCHEDULER_REGISTRY,
        ),
        (
            ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"],
            {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}},
            "lr_scheduler",
            LR_SCHEDULER_REGISTRY,
        ),
    ],
)
def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
    base = ["any.py", "--trainer.max_epochs=1"]
    argv = base + args
    new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
    assert new_argv == base + [f"--{nested_key}", str(expected)]


def test_optimizers_and_lr_schedulers_reload(tmpdir):
    base = ["any.py", "--trainer.max_epochs=1"]
    input = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--optimizer",
        "Adam",
        "--optimizer.lr=0.1",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(BoringModel, run=False)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["optimizer"]["class_path"] == "torch.optim.adam.Adam"
    assert dict_config["optimizer"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"

    # reload config
    yaml_config_file = tmpdir / "config.yaml"
    yaml_config_file.write_text(yaml_config, "utf-8")
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        LightningCLI(BoringModel, run=False)


def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir):
    class TestLightningCLI(LightningCLI):
        def __init__(self, *args):
            super().__init__(*args, run=False)

        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
            parser.add_optimizer_args(
                (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
            )
            parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
            parser.add_argument("--something", type=str, nargs="+")

    class TestModel(BoringModel):
        def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
            super().__init__()
            self.opt1_config = opt1_config
            self.opt2_config = opt2_config
            self.sch_config = sch_config
            opt1 = instantiate_class(self.parameters(), opt1_config)
            assert isinstance(opt1, torch.optim.Adam)
            opt2 = instantiate_class(self.parameters(), opt2_config)
            assert isinstance(opt2, torch.optim.ASGD)
            sch = instantiate_class(opt1, sch_config)
            assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR)

    base = ["any.py", "--trainer.max_epochs=1"]
    input = base + [
        "--lr_scheduler",
        "OneCycleLR",
        "--lr_scheduler.total_steps=10",
        "--lr_scheduler.max_lr=1",
        "--opt1",
        "Adam",
        "--opt2.lr=0.1",
        "--opt2",
        "ASGD",
        "--lr_scheduler.anneal_strategy=linear",
        "--something",
        "a",
        "b",
        "c",
    ]

    # save config
    out = StringIO()
    with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
        TestLightningCLI(TestModel)

    # validate yaml
    yaml_config = out.getvalue()
    dict_config = yaml.safe_load(yaml_config)
    assert dict_config["opt1"]["class_path"] == "torch.optim.adam.Adam"
    assert dict_config["opt2"]["class_path"] == "torch.optim.asgd.ASGD"
    assert dict_config["opt2"]["init_args"]["lr"] == 0.1
    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear"
    assert dict_config["something"] == ["a", "b", "c"]

    # reload config
    yaml_config_file = tmpdir / "config.yaml"
    yaml_config_file.write_text(yaml_config, "utf-8")
    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
        cli = TestLightningCLI(TestModel)

    assert cli.model.opt1_config["class_path"] == "torch.optim.adam.Adam"
    assert cli.model.opt2_config["class_path"] == "torch.optim.asgd.ASGD"
    assert cli.model.opt2_config["init_args"]["lr"] == 0.1
    assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
    assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear"


@RunIf(min_python="3.7.3")  # bpo-17185: `autospec=True` and `inspect.signature` do not play well
def test_lightning_cli_config_with_subcommand():
    config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    with mock.patch("sys.argv", ["any.py", f"--config={config}"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1


@RunIf(min_python="3.7.3")
def test_lightning_cli_config_before_subcommand():
    config = {
        "validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"},
        "test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"},
    }

    with mock.patch("sys.argv", ["any.py", f"--config={config}", "test"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    with mock.patch("sys.argv", ["any.py", f"--config={config}", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel)

    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1


@RunIf(min_python="3.7.3")
def test_lightning_cli_config_before_subcommand_two_configs():
    config1 = {"validate": {"trainer": {"limit_val_batches": 1}, "verbose": False, "ckpt_path": "barfoo"}}
    config2 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}

    with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "test"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1

    with mock.patch("sys.argv", ["any.py", f"--config={config1}", f"--config={config2}", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel)

    validate_mock.assert_called_once_with(cli.trainer, cli.model, verbose=False, ckpt_path="barfoo")
    assert cli.trainer.limit_val_batches == 1


@RunIf(min_python="3.7.3")
def test_lightning_cli_config_after_subcommand():
    config = {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}
    with mock.patch("sys.argv", ["any.py", "test", f"--config={config}"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, cli.model, verbose=True, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1


@RunIf(min_python="3.7.3")
def test_lightning_cli_config_before_and_after_subcommand():
    config1 = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
    config2 = {"trainer": {"fast_dev_run": 1}, "verbose": False, "ckpt_path": "foobar"}
    with mock.patch("sys.argv", ["any.py", f"--config={config1}", "test", f"--config={config2}"]), mock.patch(
        "pytorch_lightning.Trainer.test", autospec=True
    ) as test_mock:
        cli = LightningCLI(BoringModel)

    test_mock.assert_called_once_with(cli.trainer, model=cli.model, verbose=False, ckpt_path="foobar")
    assert cli.trainer.limit_test_batches == 1
    assert cli.trainer.fast_dev_run == 1


def test_lightning_cli_parse_kwargs_with_subcommands(tmpdir):
    fit_config = {"trainer": {"limit_train_batches": 2}}
    fit_config_path = tmpdir / "fit.yaml"
    fit_config_path.write_text(str(fit_config), "utf8")

    validate_config = {"trainer": {"limit_val_batches": 3}}
    validate_config_path = tmpdir / "validate.yaml"
    validate_config_path.write_text(str(validate_config), "utf8")

    parser_kwargs = {
        "fit": {"default_config_files": [str(fit_config_path)]},
        "validate": {"default_config_files": [str(validate_config_path)]},
    }

    with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
        "pytorch_lightning.Trainer.fit", autospec=True
    ) as fit_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    fit_mock.assert_called()
    assert cli.trainer.limit_train_batches == 2
    assert cli.trainer.limit_val_batches == 1.0

    with mock.patch("sys.argv", ["any.py", "validate"]), mock.patch(
        "pytorch_lightning.Trainer.validate", autospec=True
    ) as validate_mock:
        cli = LightningCLI(BoringModel, parser_kwargs=parser_kwargs)
    validate_mock.assert_called()
    assert cli.trainer.limit_train_batches == 1.0
    assert cli.trainer.limit_val_batches == 3


def test_lightning_cli_reinstantiate_trainer():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(BoringModel, run=False)
    assert cli.trainer.max_epochs == 1000

    class TestCallback(Callback):
        ...

    # make sure a new trainer can be easily created
    trainer = cli.instantiate_trainer(max_epochs=123, callbacks=[TestCallback()])
    # the new config is used
    assert trainer.max_epochs == 123
    assert {c.__class__ for c in trainer.callbacks} == {c.__class__ for c in cli.trainer.callbacks}.union(
        {TestCallback}
    )
    # the existing config is not updated
    assert cli.config_init["trainer"]["max_epochs"] is None