[CLI] Shorthand notation to instantiate optimizers and lr schedulers [2/3] (#9565)

Carlos Mocholí 2021-09-17 19:00:46 +02:00 committed by GitHub
parent 77c719f98d
commit bbcb977851
5 changed files with 451 additions and 59 deletions


@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
* Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))
* Allow easy trainer re-instantiation ([#9241](https://github.com/PyTorchLightning/pytorch-lightning/pull/9241))
* Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
- Fault-tolerant training:


@@ -665,27 +665,118 @@ Optimizers and learning rate schedulers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a
single optimizer and optionally a single learning rate scheduler. In this case the model's
:class:`~pytorch_lightning.core.lightning.LightningModule` could be left without implementing the
:code:`configure_optimizers` method since it is normally always the same and just adds boilerplate. The following code
snippet shows how to implement it:
single optimizer and optionally a single learning rate scheduler. In this case, the model's
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` could be left unimplemented since it is
normally always the same and just adds boilerplate.
The CLI works out-of-the-box with PyTorch's built-in optimizers and learning rate schedulers when
at most one of each is used.
Only the optimizer or scheduler name needs to be passed, optionally with its ``__init__`` arguments:
.. code-block:: bash

    $ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1

A corresponding example of the config file would be:

.. code-block:: yaml

    optimizer:
      class_path: torch.optim.Adam
      init_args:
        lr: 0.01
    lr_scheduler:
      class_path: torch.optim.lr_scheduler.ExponentialLR
      init_args:
        gamma: 0.1
    model:
      ...
    trainer:
      ...
.. note::

    This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
    generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
.. code-block:: python

    from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY


    @OPTIMIZER_REGISTRY
    class CustomAdam(torch.optim.Adam):
        ...


    @LR_SCHEDULER_REGISTRY
    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
        ...


    # register all `Optimizer` subclasses from the `torch.optim` package
    # This is done automatically!
    OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)

    cli = LightningCLI(...)

.. code-block:: bash

    $ python trainer.py fit --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR
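To check which names are available for the shorthand notation, the registries can be inspected directly. A minimal
sketch (the printed names are illustrative and depend on the installed ``torch`` version and any custom registrations):

.. code-block:: python

    from pytorch_lightning.utilities.cli import LR_SCHEDULER_REGISTRY, OPTIMIZER_REGISTRY

    # each registry maps a class name to the class it resolves to
    print(OPTIMIZER_REGISTRY.names)  # e.g. ['ASGD', 'Adam', 'AdamW', ..., 'CustomAdam']
    print(LR_SCHEDULER_REGISTRY.names)  # e.g. ['CosineAnnealingLR', 'ExponentialLR', ...]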
If you need to customize the key names or link arguments together, you can choose from all available optimizers and
learning rate schedulers by accessing the registries.
.. code-block:: python

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(
                OPTIMIZER_REGISTRY.classes,
                nested_key="gen_optimizer",
                link_to="model.optimizer1_init"
            )
            parser.add_optimizer_args(
                OPTIMIZER_REGISTRY.classes,
                nested_key="gen_discriminator",
                link_to="model.optimizer2_init"
            )

.. code-block:: bash

    $ python trainer.py fit \
        --gen_optimizer=Adam \
        --gen_optimizer.lr=0.01 \
        --gen_discriminator=AdamW \
        --gen_discriminator.lr=0.0001
You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
``OPTIMIZER_REGISTRY``:
.. code-block:: bash

    $ python trainer.py fit \
        --gen_optimizer.class_path=torch.optim.Adam \
        --gen_optimizer.init_args.lr=0.01 \
        --gen_discriminator.class_path=torch.optim.AdamW \
        --gen_discriminator.init_args.lr=0.0001
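With ``link_to``, the model is expected to receive these dictionaries and instantiate the classes itself. A minimal
sketch of what that could look like, assuming a hypothetical ``MyGAN`` module (not part of this change):

.. code-block:: python

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import instantiate_class


    class MyGAN(LightningModule):
        def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
            super().__init__()
            self.optimizer1_init = optimizer1_init
            self.optimizer2_init = optimizer2_init

        def configure_optimizers(self):
            # `instantiate_class` imports `class_path` and calls it with
            # `self.parameters()` plus the `init_args`
            optimizer_gen = instantiate_class(self.parameters(), self.optimizer1_init)
            optimizer_disc = instantiate_class(self.parameters(), self.optimizer2_init)
            return optimizer_gen, optimizer_disc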
If you will not be changing the class, you can manually add the arguments for specific optimizers and/or
learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those
classes. The following code snippet shows how to implement it:
.. testcode::

    import torch


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)


    cli = MyLightningCLI(MyModel)
With this the :code:`configure_optimizers` method is automatically implemented and in the config the :code:`optimizer`
and :code:`lr_scheduler` groups would accept all of the options for the given classes, in this example :code:`Adam` and
:code:`ExponentialLR`. Therefore, the config file would be structured like:
With this, in the config the :code:`optimizer` and :code:`lr_scheduler` groups would accept all of the options for the
given classes, in this example :code:`Adam` and :code:`ExponentialLR`.
Therefore, the config file would be structured like:
.. code-block:: yaml
@@ -698,37 +789,12 @@ and :code:`lr_scheduler` groups would accept all of the options for the given cl
    trainer:
      ...
And any of these arguments could be passed directly through command line. For example:
Where the arguments can be passed directly through command line without specifying the class. For example:
.. code-block:: bash

    $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
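For reference, the :code:`configure_optimizers` that gets implemented automatically in this single-class case is
roughly equivalent to the following hand-written boilerplate (a sketch, not the exact generated code):

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.2)
            return [optimizer], [lr_scheduler]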
There is also the possibility of selecting among multiple classes by giving them as a tuple. For example:
.. testcode::

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
In this case, instead of holding the init settings directly, the :code:`optimizer` group in the config should specify
:code:`class_path` and optionally :code:`init_args`. Subclasses of the classes in the tuple are also accepted.
A corresponding example of the config file would be:
.. code-block:: yaml

    optimizer:
      class_path: torch.optim.Adam
      init_args:
        lr: 0.01
And the same through command line:
.. code-block:: bash

    $ python trainer.py fit --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01
The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An
example is :code:`ReduceLROnPlateau`, which requires specifying a monitor. This would be:
@@ -763,12 +829,11 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th
cli = MyLightningCLI(MyModel)
For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with
a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including
:code:`class_path` and :code:`init_args` entries. The function
:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in
:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the
:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and
:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class`
takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments,
in this case :code:`self.parameters()`, and the :code:`init_args`.
Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
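For linked learning rate schedulers the same helper applies, except that the positional argument passed to
:func:`~pytorch_lightning.utilities.cli.instantiate_class` is the optimizer rather than the parameters. A minimal
sketch, assuming both an optimizer and a scheduler were linked into the model (the attribute names are illustrative):

.. code-block:: python

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import instantiate_class


    class MyModel(LightningModule):
        def __init__(self, optimizer_init: dict, scheduler_init: dict):
            super().__init__()
            self.optimizer_init = optimizer_init
            self.scheduler_init = scheduler_init

        def configure_optimizers(self):
            optimizer = instantiate_class(self.parameters(), self.optimizer_init)
            # for a scheduler, the optimizer is the positional argument
            scheduler = instantiate_class(optimizer, self.scheduler_init)
            return [optimizer], [scheduler]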
Notes related to reproducibility


@@ -11,11 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
from argparse import Namespace
from types import MethodType
from types import MethodType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from unittest import mock
import torch
from torch.optim import Optimizer
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
@@ -35,9 +39,57 @@ else:
ArgumentParser = object
class _Registry(dict):
def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> None:
"""Registers a class mapped to a name.
Args:
cls: the class to be mapped.
key: the name that identifies the provided class.
override: Whether to override an existing key.
"""
if key is None:
key = cls.__name__
elif not isinstance(key, str):
raise TypeError(f"`key` must be a str, found {key}")
if key in self and not override:
raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
self[key] = cls
def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
"""This function is an utility to register all classes from a module."""
for _, cls in inspect.getmembers(module, predicate=inspect.isclass):
if issubclass(cls, base_cls) and cls != base_cls:
self(cls=cls, override=override)
@property
def names(self) -> List[str]:
"""Returns the registered names."""
return list(self.keys())
@property
def classes(self) -> Tuple[Type, ...]:
"""Returns the registered classes."""
return tuple(self.values())
def __str__(self) -> str:
return f"Registered objects: {self.names}"
OPTIMIZER_REGISTRY = _Registry()
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
LR_SCHEDULER_REGISTRY = _Registry()
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
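# Illustrative usage of `_Registry` (a review sketch, not part of the module; `LitAdam` is hypothetical):
#
#     @OPTIMIZER_REGISTRY
#     class LitAdam(torch.optim.Adam):
#         ...
#
#     assert "LitAdam" in OPTIMIZER_REGISTRY.names
#     assert LitAdam in OPTIMIZER_REGISTRY.classes
#     print(OPTIMIZER_REGISTRY)  # Registered objects: ['ASGD', 'Adadelta', ..., 'LitAdam']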
class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
# use class attribute because `parse_args` is only called on the main parser
_choices: Dict[str, Tuple[Type, ...]] = {}
def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
@@ -118,6 +170,7 @@ class LightningArgumentParser(ArgumentParser):
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
if isinstance(optimizer_class, tuple):
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
self.set_choices(nested_key, optimizer_class)
else:
self.add_class_arguments(optimizer_class, nested_key, **kwargs)
self._optimizers[nested_key] = (optimizer_class, link_to)
@@ -142,10 +195,70 @@ class LightningArgumentParser(ArgumentParser):
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
if isinstance(lr_scheduler_class, tuple):
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
self.set_choices(nested_key, lr_scheduler_class)
else:
self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
argv = sys.argv
for k, classes in self._choices.items():
if not any(arg.startswith(f"--{k}") for arg in argv):
# the key wasn't passed - maybe defined in a config, maybe it's optional
continue
argv = self._convert_argv_issue_84(classes, k, argv)
self._choices.clear()
with mock.patch("sys.argv", argv):
return super().parse_args(*args, **kwargs)
def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
self._choices[nested_key] = classes
@staticmethod
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
This should be removed once implemented.
"""
passed_args, clean_argv = {}, []
argv_key = f"--{nested_key}"
# get the argv args for this nested key
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith(argv_key):
if "=" in arg:
key, value = arg.split("=")
else:
key = arg
i += 1
value = argv[i]
passed_args[key] = value
else:
clean_argv.append(arg)
i += 1
# generate the associated config string
argv_class = passed_args.pop(argv_key, None)
if argv_class is None:
# the shorthand was not used: the user passed `--<key>.class_path` and `--<key>.init_args.*` arguments
class_path = passed_args[f"{argv_key}.class_path"]
init_args_key = f"{argv_key}.init_args"
init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)}
config = str({"class_path": class_path, "init_args": init_args})
elif argv_class.startswith("{"):
# the user passed a config as a dict
config = argv_class
else:
# the user passed the shorthand format
init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period
for cls in classes:
if cls.__name__ == argv_class:
config = str(_global_add_class_path(cls, init_args))
break
else:
raise ValueError(f"Could not generate a config for {repr(argv_class)}")
return clean_argv + [argv_key, config]
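# Illustrative transformation (mirrors `test_argv_transformations_with_optimizers_and_lr_schedulers` below):
#   ["any.py", "--optimizer", "Adadelta", "--optimizer.lr", "10"]
# is rewritten to
#   ["any.py", "--optimizer", "{'class_path': 'torch.optim.adadelta.Adadelta', 'init_args': {'lr': '10'}}"]
# which jsonargparse then parses like a regular `class_path`/`init_args` subclass config.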
class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.
@@ -328,6 +441,11 @@ class LightningCLI:
self.add_default_arguments_to_parser(parser)
self.add_core_arguments_to_parser(parser)
self.add_arguments_to_parser(parser)
# add default optimizer args if necessary
if not parser._optimizers:  # skip if already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes)
if not parser._lr_schedulers:  # skip if already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
self.link_optimizers_and_lr_schedulers(parser)
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:


@@ -7,6 +7,6 @@ torchtext>=0.7
onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.19.0
jsonargparse[signatures]>=3.19.3
gcsfs>=2021.5.0
rich>=10.2.2


@@ -33,7 +33,15 @@ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback
from pytorch_lightning.utilities.cli import (
instantiate_class,
LightningArgumentParser,
LightningCLI,
LR_SCHEDULER_REGISTRY,
OPTIMIZER_REGISTRY,
SaveConfigCallback,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
@@ -678,12 +686,20 @@ def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
@pytest.mark.parametrize("use_registries", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1")
parser.add_optimizer_args(
OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam,
nested_key="optim1",
link_to="model.optim1",
)
parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler")
parser.add_lr_scheduler_args(
LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR,
link_to="model.scheduler",
)
class TestModel(BoringModel):
def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
@@ -692,20 +708,26 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
self.optim2 = instantiate_class(self.parameters(), optim2)
self.scheduler = instantiate_class(self.optim1, scheduler)
cli_args = [
"fit",
f"--trainer.default_root_dir={tmpdir}",
"--trainer.max_epochs=1",
"--optim2.class_path=torch.optim.SGD",
"--optim2.init_args.lr=0.01",
"--lr_scheduler.gamma=0.2",
]
cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", "--lr_scheduler.gamma=0.2"]
if use_registries:
cli_args += [
"--optim1",
"Adam",
"--optim1.weight_decay",
"0.001",
"--optim2=SGD",
"--optim2.lr=0.01",
"--lr_scheduler=ExponentialLR",
]
else:
cli_args += ["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(TestModel)
assert isinstance(cli.model.optim1, torch.optim.Adam)
assert isinstance(cli.model.optim2, torch.optim.SGD)
assert cli.model.optim2.param_groups[0]["lr"] == 0.01
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
@@ -829,6 +851,190 @@ def test_lightning_cli_run():
assert isinstance(cli.model, LightningModule)
@OPTIMIZER_REGISTRY
class CustomAdam(torch.optim.Adam):
pass
@LR_SCHEDULER_REGISTRY
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
pass
def test_registries(tmpdir):
assert "SGD" in OPTIMIZER_REGISTRY.names
assert "RMSprop" in OPTIMIZER_REGISTRY.names
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
with pytest.raises(MisconfigurationException, match="is already present in the registry"):
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)
def test_registries_resolution():
"""This test validates registries are used when simplified command line are being used."""
cli_args = [
"--optimizer",
"Adam",
"--optimizer.lr",
"0.0001",
"--lr_scheduler",
"StepLR",
"--lr_scheduler.step_size=50",
]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(BoringModel, run=False)
optimizers, lr_scheduler = cli.model.configure_optimizers()
assert isinstance(optimizers[0], torch.optim.Adam)
assert optimizers[0].param_groups[0]["lr"] == 0.0001
assert lr_scheduler[0].step_size == 50
@pytest.mark.parametrize(
["args", "expected", "nested_key", "registry"],
[
(
["--optimizer", "Adadelta"],
{"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}},
"optimizer",
OPTIMIZER_REGISTRY,
),
(
["--optimizer", "Adadelta", "--optimizer.lr", "10"],
{"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}},
"optimizer",
OPTIMIZER_REGISTRY,
),
(
["--lr_scheduler", "OneCycleLR"],
{"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}},
"lr_scheduler",
LR_SCHEDULER_REGISTRY,
),
(
["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"],
{"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}},
"lr_scheduler",
LR_SCHEDULER_REGISTRY,
),
],
)
def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
base = ["any.py", "--trainer.max_epochs=1"]
argv = base + args
new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
assert new_argv == base + [f"--{nested_key}", str(expected)]
def test_optimizers_and_lr_schedulers_reload(tmpdir):
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
"--lr_scheduler",
"OneCycleLR",
"--lr_scheduler.total_steps=10",
"--lr_scheduler.max_lr=1",
"--optimizer",
"Adam",
"--optimizer.lr=0.1",
]
# save config
out = StringIO()
with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
LightningCLI(BoringModel, run=False)
# validate yaml
yaml_config = out.getvalue()
dict_config = yaml.safe_load(yaml_config)
assert dict_config["optimizer"]["class_path"] == "torch.optim.adam.Adam"
assert dict_config["optimizer"]["init_args"]["lr"] == 0.1
assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
# reload config
yaml_config_file = tmpdir / "config.yaml"
yaml_config_file.write_text(yaml_config, "utf-8")
with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
LightningCLI(BoringModel, run=False)
def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir):
class TestLightningCLI(LightningCLI):
def __init__(self, *args):
super().__init__(*args, run=False)
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
parser.add_optimizer_args(
(torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
)
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
parser.add_argument("--something", type=str, nargs="+")
class TestModel(BoringModel):
def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
super().__init__()
self.opt1_config = opt1_config
self.opt2_config = opt2_config
self.sch_config = sch_config
opt1 = instantiate_class(self.parameters(), opt1_config)
assert isinstance(opt1, torch.optim.Adam)
opt2 = instantiate_class(self.parameters(), opt2_config)
assert isinstance(opt2, torch.optim.ASGD)
sch = instantiate_class(opt1, sch_config)
assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR)
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
"--lr_scheduler",
"OneCycleLR",
"--lr_scheduler.total_steps=10",
"--lr_scheduler.max_lr=1",
"--opt1",
"Adam",
"--opt2.lr=0.1",
"--opt2",
"ASGD",
"--lr_scheduler.anneal_strategy=linear",
"--something",
"a",
"b",
"c",
]
# save config
out = StringIO()
with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
TestLightningCLI(TestModel)
# validate yaml
yaml_config = out.getvalue()
dict_config = yaml.safe_load(yaml_config)
assert dict_config["opt1"]["class_path"] == "torch.optim.adam.Adam"
assert dict_config["opt2"]["class_path"] == "torch.optim.asgd.ASGD"
assert dict_config["opt2"]["init_args"]["lr"] == 0.1
assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear"
assert dict_config["something"] == ["a", "b", "c"]
# reload config
yaml_config_file = tmpdir / "config.yaml"
yaml_config_file.write_text(yaml_config, "utf-8")
with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
cli = TestLightningCLI(TestModel)
assert cli.model.opt1_config["class_path"] == "torch.optim.adam.Adam"
assert cli.model.opt2_config["class_path"] == "torch.optim.asgd.ASGD"
assert cli.model.opt2_config["init_args"]["lr"] == 0.1
assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear"
@RunIf(min_python="3.7.3") # bpo-17185: `autospec=True` and `inspect.signature` do not play well
def test_lightning_cli_config_with_subcommand():
config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}