From bbcb977851526cd97d25adf1d9a36b2955b53708 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Fri, 17 Sep 2021 19:00:46 +0200
Subject: [PATCH] [CLI] Shorthand notation to instantiate optimizers and lr
 schedulers [2/3] (#9565)

---
 CHANGELOG.md                         |   3 +
 docs/source/common/lightning_cli.rst | 155 ++++++++++++------
 pytorch_lightning/utilities/cli.py   | 120 +++++++++++++-
 requirements/extra.txt               |   2 +-
 tests/utilities/test_cli.py          | 230 +++++++++++++++++++++++++--
 5 files changed, 451 insertions(+), 59 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c039c8ba48..660f059ca3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
     * Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))
     * Allow easy trainer re-instantiation ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/9241))
+    * Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
+    * Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
+    * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
 
 - Fault-tolerant training:
diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst
index b873c36168..e664e36aa0 100644
--- a/docs/source/common/lightning_cli.rst
+++ b/docs/source/common/lightning_cli.rst
@@ -665,27 +665,118 @@ Optimizers and learning rate schedulers
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a
-single optimizer and optionally a single learning rate scheduler. In this case the model's
-:class:`~pytorch_lightning.core.lightning.LightningModule` could be left without implementing the
-:code:`configure_optimizers` method since it is normally always the same and just adds boilerplate. The following code
-snippet shows how to implement it:
+single optimizer and optionally a single learning rate scheduler. In this case, the model's
+:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` could be left unimplemented since it is
+normally always the same and just adds boilerplate.
+
+The CLI works out-of-the-box with PyTorch's built-in optimizers and learning rate schedulers when
+at most one of each is used.
+Only the optimizer or scheduler name needs to be passed, optionally with its ``__init__`` arguments:
+
+.. code-block:: bash
+
+    $ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1
+
+A corresponding example of the config file would be:
+
+.. code-block:: yaml
+
+    optimizer:
+      class_path: torch.optim.Adam
+      init_args:
+        lr: 0.01
+    lr_scheduler:
+      class_path: torch.optim.lr_scheduler.ExponentialLR
+      init_args:
+        gamma: 0.1
+    model:
+      ...
+    trainer:
+      ...
+
+.. note::
+
+    This short-hand notation is only supported in the shell and not inside a configuration file. The configuration file
+    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
+
+Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
+
+.. code-block:: python
+
+    from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
+
+
+    @OPTIMIZER_REGISTRY
+    class CustomAdam(torch.optim.Adam):
+        ...
+
+
+    @LR_SCHEDULER_REGISTRY
+    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
+        ...
+
+
+    # register all `Optimizer` subclasses from the `torch.optim` package
+    # This is done automatically!
+    OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
+
+    cli = LightningCLI(...)
+
+.. code-block:: bash
+
+    $ python trainer.py fit --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR
+
+If you need to customize the key names or link arguments together, you can choose from all available optimizers and
+learning rate schedulers by accessing the registries.
+
+.. code-block:: python
+
+    class MyLightningCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.add_optimizer_args(
+                OPTIMIZER_REGISTRY.classes,
+                nested_key="gen_optimizer",
+                link_to="model.optimizer1_init"
+            )
+            parser.add_optimizer_args(
+                OPTIMIZER_REGISTRY.classes,
+                nested_key="gen_discriminator",
+                link_to="model.optimizer2_init"
+            )
+
+.. code-block:: bash
+
+    $ python trainer.py fit \
+        --gen_optimizer=Adam \
+        --gen_optimizer.lr=0.01 \
+        --gen_discriminator=AdamW \
+        --gen_discriminator.lr=0.0001
+
+You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
+``OPTIMIZER_REGISTRY``:
+
+.. code-block:: bash
+
+    $ python trainer.py fit \
+        --gen_optimizer.class_path=torch.optim.Adam \
+        --gen_optimizer.init_args.lr=0.01 \
+        --gen_discriminator.class_path=torch.optim.AdamW \
+        --gen_discriminator.init_args.lr=0.0001
+
+If you will not be changing the class, you can manually add the arguments for specific optimizers and/or
+learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those
+classes. The following code snippet shows how to implement it:
 
 .. testcode::
 
-    import torch
-
-
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_optimizer_args(torch.optim.Adam)
             parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
 
-
-    cli = MyLightningCLI(MyModel)
-
-With this the :code:`configure_optimizers` method is automatically implemented and in the config the :code:`optimizer`
-and :code:`lr_scheduler` groups would accept all of the options for the given classes, in this example :code:`Adam` and
-:code:`ExponentialLR`. Therefore, the config file would be structured like:
+With this, in the config the :code:`optimizer` and :code:`lr_scheduler` groups would accept all of the options for the
+given classes, in this example :code:`Adam` and :code:`ExponentialLR`.
+Therefore, the config file would be structured like:
 
 .. code-block:: yaml
 
@@ -698,37 +789,12 @@ and :code:`lr_scheduler` groups would accept all of the options for the given cl
   trainer:
     ...
 
-And any of these arguments could be passed directly through command line. For example:
+Where the arguments can be passed directly through command line without specifying the class. For example:
 
 .. code-block:: bash
 
     $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
 
-There is also the possibility of selecting among multiple classes by giving them as a tuple. For example:
-
-.. testcode::
-
-    class MyLightningCLI(LightningCLI):
-        def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
-
-In this case in the config the :code:`optimizer` group instead of having directly init settings, it should specify
-:code:`class_path` and optionally :code:`init_args`. Sub-classes of the classes in the tuple would also be accepted.
-A corresponding example of the config file would be:
-
-.. code-block:: yaml
-
-    optimizer:
-      class_path: torch.optim.Adam
-      init_args:
-        lr: 0.01
-
-And the same through command line:
-
-.. code-block:: bash
-
-    $ python trainer.py fit --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01
-
 The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An
 example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. This would be:
@@ -763,12 +829,11 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th
 
     cli = MyLightningCLI(MyModel)
 
-For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with
-a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including
-:code:`class_path` and :code:`init_args` entries. The function
-:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in
-:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the
-:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
+The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and
+:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class`
+takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments,
+in this case :code:`self.parameters()`, and the :code:`init_args`.
+Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
 
 
 Notes related to reproducibility
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index 1d437e69a9..b27a1a12ca 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -11,11 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
+import sys
 from argparse import Namespace
-from types import MethodType
+from types import MethodType, ModuleType
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from unittest import mock
 
+import torch
 from torch.optim import Optimizer
 
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
@@ -35,9 +39,57 @@ else:
     ArgumentParser = object
 
 
+class _Registry(dict):
+    def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> None:
+        """Registers a class mapped to a name.
+
+        Args:
+            cls: the class to be mapped.
+            key: the name that identifies the provided class.
+            override: Whether to override an existing key.
+        """
+        if key is None:
+            key = cls.__name__
+        elif not isinstance(key, str):
+            raise TypeError(f"`key` must be a str, found {key}")
+
+        if key in self and not override:
+            raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
+        self[key] = cls
+
+    def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
+        """This function is a utility to register all classes from a module."""
+        for _, cls in inspect.getmembers(module, predicate=inspect.isclass):
+            if issubclass(cls, base_cls) and cls != base_cls:
+                self(cls=cls, override=override)
+
+    @property
+    def names(self) -> List[str]:
+        """Returns the registered names."""
+        return list(self.keys())
+
+    @property
+    def classes(self) -> Tuple[Type, ...]:
+        """Returns the registered classes."""
+        return tuple(self.values())
+
+    def __str__(self) -> str:
+        return f"Registered objects: {self.names}"
+
+
+OPTIMIZER_REGISTRY = _Registry()
+OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
+
+LR_SCHEDULER_REGISTRY = _Registry()
+LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+
+
 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
 
+    # use class attribute because `parse_args` is only called on the main parser
+    _choices: Dict[str, Tuple[Type, ...]] = {}
+
     def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
         """Initialize argument parser that supports configuration file input.
 
@@ -118,6 +170,7 @@ class LightningArgumentParser(ArgumentParser):
         kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
         if isinstance(optimizer_class, tuple):
             self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
+            self.set_choices(nested_key, optimizer_class)
         else:
             self.add_class_arguments(optimizer_class, nested_key, **kwargs)
         self._optimizers[nested_key] = (optimizer_class, link_to)
@@ -142,10 +195,70 @@ class LightningArgumentParser(ArgumentParser):
         kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
         if isinstance(lr_scheduler_class, tuple):
             self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
+            self.set_choices(nested_key, lr_scheduler_class)
         else:
             self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
         self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
 
+    def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
+        argv = sys.argv
+        for k, classes in self._choices.items():
+            if not any(arg.startswith(f"--{k}") for arg in argv):
+                # the key wasn't passed - maybe defined in a config, maybe it's optional
+                continue
+            argv = self._convert_argv_issue_84(classes, k, argv)
+        self._choices.clear()
+        with mock.patch("sys.argv", argv):
+            return super().parse_args(*args, **kwargs)
+
+    def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
+        self._choices[nested_key] = classes
+
+    @staticmethod
+    def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
+        """Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
+
+        This should be removed once implemented.
+        """
+        passed_args, clean_argv = {}, []
+        argv_key = f"--{nested_key}"
+        # get the argv args for this nested key
+        i = 0
+        while i < len(argv):
+            arg = argv[i]
+            if arg.startswith(argv_key):
+                if "=" in arg:
+                    key, value = arg.split("=")
+                else:
+                    key = arg
+                    i += 1
+                    value = argv[i]
+                passed_args[key] = value
+            else:
+                clean_argv.append(arg)
+            i += 1
+        # generate the associated config file
+        argv_class = passed_args.pop(argv_key, None)
+        if argv_class is None:
+            # the user passed a config as a str
+            class_path = passed_args[f"{argv_key}.class_path"]
+            init_args_key = f"{argv_key}.init_args"
+            init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)}
+            config = str({"class_path": class_path, "init_args": init_args})
+        elif argv_class.startswith("{"):
+            # the user passed a config as a dict
+            config = argv_class
+        else:
+            # the user passed the shorthand format
+            init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()}  # +1 to account for the period
+            for cls in classes:
+                if cls.__name__ == argv_class:
+                    config = str(_global_add_class_path(cls, init_args))
+                    break
+            else:
+                raise ValueError(f"Could not generate a config for {repr(argv_class)}")
+        return clean_argv + [argv_key, config]
+
 
 class SaveConfigCallback(Callback):
     """Saves a LightningCLI config to the log_dir when training starts.
@@ -328,6 +441,11 @@ class LightningCLI:
         self.add_default_arguments_to_parser(parser)
         self.add_core_arguments_to_parser(parser)
         self.add_arguments_to_parser(parser)
+        # add default optimizer args if necessary
+        if not parser._optimizers:  # skip if already added by the user in `add_arguments_to_parser`
+            parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes)
+        if not parser._lr_schedulers:  # skip if already added by the user in `add_arguments_to_parser`
+            parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
         self.link_optimizers_and_lr_schedulers(parser)
 
     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
diff --git a/requirements/extra.txt b/requirements/extra.txt
index dfffc6fce8..e7a62d3071 100644
--- a/requirements/extra.txt
+++ b/requirements/extra.txt
@@ -7,6 +7,6 @@ torchtext>=0.7
 onnx>=1.7.0
 onnxruntime>=1.3.0
 hydra-core>=1.0
-jsonargparse[signatures]>=3.19.0
+jsonargparse[signatures]>=3.19.3
 gcsfs>=2021.5.0
 rich>=10.2.2
diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index e526027708..1cd12a33b7 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -33,7 +33,15 @@ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
-from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback
+from pytorch_lightning.utilities.cli import (
+    instantiate_class,
+    LightningArgumentParser,
+    LightningCLI,
+    LR_SCHEDULER_REGISTRY,
+    OPTIMIZER_REGISTRY,
+    SaveConfigCallback,
+)
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
@@ -678,12 +686,20 @@ def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
     assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50
 
 
-def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
+@pytest.mark.parametrize("use_registries", [False, True])
+def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir):
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1")
+            parser.add_optimizer_args(
+                OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam,
+                nested_key="optim1",
+                link_to="model.optim1",
+            )
             parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
-            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler")
+            parser.add_lr_scheduler_args(
+                LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR,
+                link_to="model.scheduler",
+            )
 
     class TestModel(BoringModel):
         def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
@@ -692,20 +708,26 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
             self.optim2 = instantiate_class(self.parameters(), optim2)
             self.scheduler = instantiate_class(self.optim1, scheduler)
 
-    cli_args = [
-        "fit",
-        f"--trainer.default_root_dir={tmpdir}",
-        "--trainer.max_epochs=1",
-        "--optim2.class_path=torch.optim.SGD",
-        "--optim2.init_args.lr=0.01",
-        "--lr_scheduler.gamma=0.2",
-    ]
+    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", "--lr_scheduler.gamma=0.2"]
+    if use_registries:
+        cli_args += [
+            "--optim1",
+            "Adam",
+            "--optim1.weight_decay",
+            "0.001",
+            "--optim2=SGD",
+            "--optim2.lr=0.01",
+            "--lr_scheduler=ExponentialLR",
+        ]
+    else:
+        cli_args += ["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(TestModel)
 
     assert isinstance(cli.model.optim1, torch.optim.Adam)
     assert isinstance(cli.model.optim2, torch.optim.SGD)
+    assert cli.model.optim2.param_groups[0]["lr"] == 0.01
     assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
 
 
@@ -829,6 +851,190 @@ def test_lightning_cli_run():
     assert isinstance(cli.model, LightningModule)
 
 
+@OPTIMIZER_REGISTRY
+class CustomAdam(torch.optim.Adam):
+    pass
+
+
+@LR_SCHEDULER_REGISTRY
+class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
+    pass
+
+
+def test_registries(tmpdir):
+    assert "SGD" in OPTIMIZER_REGISTRY.names
+    assert "RMSprop" in OPTIMIZER_REGISTRY.names
+    assert "CustomAdam" in OPTIMIZER_REGISTRY.names
+
+    assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
+    assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
+    assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
+
+    with pytest.raises(MisconfigurationException, match="is already present in the registry"):
+        OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
+    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)
+
+
+def test_registries_resolution():
+    """This test validates that the registries are used when the simplified command line notation is used."""
+    cli_args = [
+        "--optimizer",
+        "Adam",
+        "--optimizer.lr",
+        "0.0001",
+        "--lr_scheduler",
+        "StepLR",
+        "--lr_scheduler.step_size=50",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = LightningCLI(BoringModel, run=False)
+
+    optimizers, lr_scheduler = cli.model.configure_optimizers()
+    assert isinstance(optimizers[0], torch.optim.Adam)
+    assert optimizers[0].param_groups[0]["lr"] == 0.0001
+    assert lr_scheduler[0].step_size == 50
+
+
+@pytest.mark.parametrize(
+    ["args", "expected", "nested_key", "registry"],
+    [
+        (
+            ["--optimizer", "Adadelta"],
+            {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}},
+            "optimizer",
+            OPTIMIZER_REGISTRY,
+        ),
+        (
+            ["--optimizer", "Adadelta", "--optimizer.lr", "10"],
+            {"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}},
+            "optimizer",
+            OPTIMIZER_REGISTRY,
+        ),
+        (
+            ["--lr_scheduler", "OneCycleLR"],
+            {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}},
+            "lr_scheduler",
+            LR_SCHEDULER_REGISTRY,
+        ),
+        (
+            ["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"],
+            {"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}},
+            "lr_scheduler",
+            LR_SCHEDULER_REGISTRY,
+        ),
+    ],
+)
+def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
+    base = ["any.py", "--trainer.max_epochs=1"]
+    argv = base + args
+    new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
+    assert new_argv == base + [f"--{nested_key}", str(expected)]
+
+
+def test_optimizers_and_lr_schedulers_reload(tmpdir):
+    base = ["any.py", "--trainer.max_epochs=1"]
+    input = base + [
+        "--lr_scheduler",
+        "OneCycleLR",
+        "--lr_scheduler.total_steps=10",
+        "--lr_scheduler.max_lr=1",
+        "--optimizer",
+        "Adam",
+        "--optimizer.lr=0.1",
+    ]
+
+    # save config
+    out = StringIO()
+    with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
+        LightningCLI(BoringModel, run=False)
+
+    # validate yaml
+    yaml_config = out.getvalue()
+    dict_config = yaml.safe_load(yaml_config)
+    assert dict_config["optimizer"]["class_path"] == "torch.optim.adam.Adam"
+    assert dict_config["optimizer"]["init_args"]["lr"] == 0.1
+    assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
+
+    # reload config
+    yaml_config_file = tmpdir / "config.yaml"
+    yaml_config_file.write_text(yaml_config, "utf-8")
+    with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
+        LightningCLI(BoringModel, run=False)
+
+
+def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir):
+    class TestLightningCLI(LightningCLI):
+        def __init__(self, *args):
+            super().__init__(*args, run=False)
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
+            parser.add_optimizer_args(
+                (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
+            )
+            parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
+            parser.add_argument("--something", type=str, nargs="+")
+
+    class TestModel(BoringModel):
+        def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
+            super().__init__()
+            self.opt1_config = opt1_config
+            self.opt2_config = opt2_config
+            self.sch_config = sch_config
+            opt1 = instantiate_class(self.parameters(), opt1_config)
+            assert isinstance(opt1, torch.optim.Adam)
+            opt2 = instantiate_class(self.parameters(), opt2_config)
+            assert isinstance(opt2, torch.optim.ASGD)
+            sch = instantiate_class(opt1, sch_config)
+            assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR)
+
+    base = ["any.py", "--trainer.max_epochs=1"]
+    input = base + [
+        "--lr_scheduler",
+        "OneCycleLR",
+        "--lr_scheduler.total_steps=10",
"--lr_scheduler.max_lr=1", + "--opt1", + "Adam", + "--opt2.lr=0.1", + "--opt2", + "ASGD", + "--lr_scheduler.anneal_strategy=linear", + "--something", + "a", + "b", + "c", + ] + + # save config + out = StringIO() + with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit): + TestLightningCLI(TestModel) + + # validate yaml + yaml_config = out.getvalue() + dict_config = yaml.safe_load(yaml_config) + assert dict_config["opt1"]["class_path"] == "torch.optim.adam.Adam" + assert dict_config["opt2"]["class_path"] == "torch.optim.asgd.ASGD" + assert dict_config["opt2"]["init_args"]["lr"] == 0.1 + assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear" + assert dict_config["something"] == ["a", "b", "c"] + + # reload config + yaml_config_file = tmpdir / "config.yaml" + yaml_config_file.write_text(yaml_config, "utf-8") + with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]): + cli = TestLightningCLI(TestModel) + + assert cli.model.opt1_config["class_path"] == "torch.optim.adam.Adam" + assert cli.model.opt2_config["class_path"] == "torch.optim.asgd.ASGD" + assert cli.model.opt2_config["init_args"]["lr"] == 0.1 + assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR" + assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear" + + @RunIf(min_python="3.7.3") # bpo-17185: `autospec=True` and `inspect.signature` do not play well def test_lightning_cli_config_with_subcommand(): config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}