[CLI] Shorthand notation to instantiate optimizers and lr schedulers [2/3] (#9565)
@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
* Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
|
||||
* Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))
|
||||
* Allow easy trainer re-instantiation ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/9241))
|
||||
* Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
|
||||
* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
|
||||
* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
|
||||
|
||||
|
||||
- Fault-tolerant training:
|
||||
|
|
|
@ -665,27 +665,118 @@ Optimizers and learning rate schedulers
|
|||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a
|
||||
single optimizer and optionally a single learning rate scheduler. In this case the model's
|
||||
:class:`~pytorch_lightning.core.lightning.LightningModule` could be left without implementing the
|
||||
:code:`configure_optimizers` method since it is normally always the same and just adds boilerplate. The following code
|
||||
snippet shows how to implement it:
|
||||
single optimizer and optionally a single learning rate scheduler. In this case, the model's
|
||||
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` could be left unimplemented since it is
|
||||
usually the same and just adds boilerplate.
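
For reference, a minimal sketch of the boilerplate that the CLI makes unnecessary could look like the following
(the ``Adam`` and ``ExponentialLR`` choices are placeholders matching the examples below):

.. code-block:: python

    import torch
    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):
        def configure_optimizers(self):
            # placeholder choices; with the CLI this whole method can be omitted
            optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
            scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
            return [optimizer], [scheduler]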
|
||||
|
||||
The CLI works out-of-the-box with PyTorch's built-in optimizers and learning rate schedulers when
|
||||
at most one of each is used.
|
||||
Only the optimizer or scheduler name needs to be passed, optionally with its ``__init__`` arguments:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1
|
||||
|
||||
A corresponding example of the config file would be:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
optimizer:
|
||||
class_path: torch.optim.Adam
|
||||
init_args:
|
||||
lr: 0.01
|
||||
lr_scheduler:
|
||||
class_path: torch.optim.lr_scheduler.ExponentialLR
|
||||
init_args:
|
||||
gamma: 0.1
|
||||
model:
|
||||
...
|
||||
trainer:
|
||||
...
|
||||
|
||||
.. note::
|
||||
|
||||
This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
|
||||
generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
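
For example, the ``class_path`` configuration shown above can be produced with:

.. code-block:: bash

    $ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1 --print_config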
|
||||
|
||||
Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import torch
from torch.optim import Optimizer

from pytorch_lightning.utilities.cli import LightningCLI, OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
|
||||
|
||||
|
||||
@OPTIMIZER_REGISTRY
|
||||
class CustomAdam(torch.optim.Adam):
|
||||
...
|
||||
|
||||
|
||||
@LR_SCHEDULER_REGISTRY
|
||||
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
|
||||
...
|
||||
|
||||
|
||||
# register all `Optimizer` subclasses from the `torch.optim` package
|
||||
# This is done automatically!
|
||||
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
|
||||
|
||||
cli = LightningCLI(...)
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR
|
||||
|
||||
If you need to customize the key names or link arguments together, you can choose from all available optimizers and
|
||||
learning rate schedulers by accessing the registries.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_optimizer_args(
|
||||
OPTIMIZER_REGISTRY.classes,
|
||||
nested_key="gen_optimizer",
|
||||
link_to="model.optimizer1_init"
|
||||
)
|
||||
parser.add_optimizer_args(
|
||||
OPTIMIZER_REGISTRY.classes,
|
||||
nested_key="gen_discriminator",
|
||||
link_to="model.optimizer2_init"
|
||||
)
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit \
|
||||
--gen_optimizer=Adam \
|
||||
--gen_optimizer.lr=0.01 \
|
||||
--gen_discriminator=AdamW \
|
||||
--gen_discriminator.lr=0.0001
|
||||
|
||||
You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
|
||||
``OPTIMIZER_REGISTRY``:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit \
|
||||
--gen_optimizer.class_path=torch.optim.Adam \
|
||||
--gen_optimizer.init_args.lr=0.01 \
|
||||
--gen_discriminator.class_path=torch.optim.AdamW \
|
||||
--gen_discriminator.init_args.lr=0.0001
|
||||
|
||||
If you will not be changing the class, you can manually add the arguments for specific optimizers and/or
|
||||
learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those
|
||||
classes. The following code snippet shows how to implement it:
|
||||
|
||||
.. testcode::
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_optimizer_args(torch.optim.Adam)
|
||||
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
|
||||
|
||||
|
||||
cli = MyLightningCLI(MyModel)
|
||||
|
||||
With this the :code:`configure_optimizers` method is automatically implemented and in the config the :code:`optimizer`
|
||||
and :code:`lr_scheduler` groups would accept all of the options for the given classes, in this example :code:`Adam` and
|
||||
:code:`ExponentialLR`. Therefore, the config file would be structured like:
|
||||
With this, in the config the :code:`optimizer` and :code:`lr_scheduler` groups would accept all of the options for the
|
||||
given classes, in this example :code:`Adam` and :code:`ExponentialLR`.
|
||||
Therefore, the config file would be structured like:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
|
@ -698,37 +789,12 @@ and :code:`lr_scheduler` groups would accept all of the options for the given cl
|
|||
trainer:
|
||||
...
|
||||
|
||||
And any of these arguments could be passed directly through command line. For example:
|
||||
These arguments can also be passed directly through the command line without specifying the class. For example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
|
||||
|
||||
There is also the possibility of selecting among multiple classes by giving them as a tuple. For example:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
|
||||
|
||||
In this case, instead of having the init settings directly, the :code:`optimizer` group in the config should specify
|
||||
:code:`class_path` and optionally :code:`init_args`. Subclasses of the classes in the tuple would also be accepted.
|
||||
A corresponding example of the config file would be:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
optimizer:
|
||||
class_path: torch.optim.Adam
|
||||
init_args:
|
||||
lr: 0.01
|
||||
|
||||
And the same through command line:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01
|
||||
|
||||
The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An
|
||||
example is :code:`ReduceLROnPlateau`, which requires specifying a monitor. This would be:
|
||||
|
||||
|
@ -763,12 +829,11 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th
|
|||
|
||||
cli = MyLightningCLI(MyModel)
|
||||
|
||||
For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with
|
||||
a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including
|
||||
:code:`class_path` and :code:`init_args` entries. The function
|
||||
:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in
|
||||
:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the
|
||||
:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
|
||||
The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and
|
||||
:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class`
|
||||
takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments,
|
||||
in this case :code:`self.parameters()`, and the :code:`init_args`.
|
||||
Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
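
As an illustration, a minimal sketch of a module consuming such a linked dictionary could look like the following
(the ``optimizer_init`` parameter name is just an example of a ``link_to`` target, e.g. ``link_to="model.optimizer_init"``):

.. code-block:: python

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import instantiate_class


    class MyModel(LightningModule):
        def __init__(self, optimizer_init: dict):
            super().__init__()
            self.optimizer_init = optimizer_init

        def configure_optimizers(self):
            # `instantiate_class` imports `class_path` and calls it with
            # `self.parameters()` followed by the `init_args`
            return instantiate_class(self.parameters(), self.optimizer_init)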
|
||||
|
||||
|
||||
Notes related to reproducibility
|
||||
|
|
|
@ -11,11 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from types import MethodType
|
||||
from types import MethodType, ModuleType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
|
||||
|
@ -35,9 +39,57 @@ else:
|
|||
ArgumentParser = object
|
||||
|
||||
|
||||
class _Registry(dict):
|
||||
def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> None:
|
||||
"""Registers a class mapped to a name.
|
||||
|
||||
Args:
|
||||
cls: the class to be mapped.
|
||||
key: the name that identifies the provided class.
|
||||
override: Whether to override an existing key.
|
||||
"""
|
||||
if key is None:
|
||||
key = cls.__name__
|
||||
elif not isinstance(key, str):
|
||||
raise TypeError(f"`key` must be a str, found {key}")
|
||||
|
||||
if key in self and not override:
|
||||
raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
|
||||
self[key] = cls
|
||||
|
||||
def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
|
||||
"""Utility to register all subclasses of ``base_cls`` found in ``module``."""
|
||||
for _, cls in inspect.getmembers(module, predicate=inspect.isclass):
|
||||
if issubclass(cls, base_cls) and cls != base_cls:
|
||||
self(cls=cls, override=override)
|
||||
|
||||
@property
|
||||
def names(self) -> List[str]:
|
||||
"""Returns the registered names."""
|
||||
return list(self.keys())
|
||||
|
||||
@property
|
||||
def classes(self) -> Tuple[Type, ...]:
|
||||
"""Returns the registered classes."""
|
||||
return tuple(self.values())
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Registered objects: {self.names}"
|
||||
|
||||
|
||||
OPTIMIZER_REGISTRY = _Registry()
|
||||
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
|
||||
|
||||
LR_SCHEDULER_REGISTRY = _Registry()
|
||||
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
|
||||
|
||||
|
||||
class LightningArgumentParser(ArgumentParser):
|
||||
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""
|
||||
|
||||
# use class attribute because `parse_args` is only called on the main parser
|
||||
_choices: Dict[str, Tuple[Type, ...]] = {}
|
||||
|
||||
def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
|
||||
"""Initialize argument parser that supports configuration file input.
|
||||
|
||||
|
@ -118,6 +170,7 @@ class LightningArgumentParser(ArgumentParser):
|
|||
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
|
||||
if isinstance(optimizer_class, tuple):
|
||||
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
|
||||
self.set_choices(nested_key, optimizer_class)
|
||||
else:
|
||||
self.add_class_arguments(optimizer_class, nested_key, **kwargs)
|
||||
self._optimizers[nested_key] = (optimizer_class, link_to)
|
||||
|
@ -142,10 +195,70 @@ class LightningArgumentParser(ArgumentParser):
|
|||
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
|
||||
if isinstance(lr_scheduler_class, tuple):
|
||||
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
|
||||
self.set_choices(nested_key, lr_scheduler_class)
|
||||
else:
|
||||
self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
|
||||
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
|
||||
|
||||
def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
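"""Parse arguments, first converting any shorthand class names found in ``sys.argv`` into the full ``class_path``/``init_args`` form."""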
|
||||
argv = sys.argv
|
||||
for k, classes in self._choices.items():
|
||||
if not any(arg.startswith(f"--{k}") for arg in argv):
|
||||
# the key wasn't passed - maybe defined in a config, maybe it's optional
|
||||
continue
|
||||
argv = self._convert_argv_issue_84(classes, k, argv)
|
||||
self._choices.clear()
|
||||
with mock.patch("sys.argv", argv):
|
||||
return super().parse_args(*args, **kwargs)
|
||||
|
||||
def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
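"""Record the classes that can be selected by name for ``nested_key``, used to resolve the shorthand notation in ``parse_args``."""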
|
||||
self._choices[nested_key] = classes
|
||||
|
||||
@staticmethod
|
||||
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
|
||||
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
|
||||
|
||||
This should be removed once implemented.
|
||||
"""
|
||||
passed_args, clean_argv = {}, []
|
||||
argv_key = f"--{nested_key}"
|
||||
# get the argv args for this nested key
|
||||
i = 0
|
||||
while i < len(argv):
|
||||
arg = argv[i]
|
||||
if arg.startswith(argv_key):
|
||||
if "=" in arg:
|
||||
key, value = arg.split("=")
|
||||
else:
|
||||
key = arg
|
||||
i += 1
|
||||
value = argv[i]
|
||||
passed_args[key] = value
|
||||
else:
|
||||
clean_argv.append(arg)
|
||||
i += 1
|
||||
# generate the associated config value
|
||||
argv_class = passed_args.pop(argv_key, None)
|
||||
if argv_class is None:
|
||||
# the user passed the class_path and init_args as individual dotted arguments
|
||||
class_path = passed_args[f"{argv_key}.class_path"]
|
||||
init_args_key = f"{argv_key}.init_args"
|
||||
init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)}
|
||||
config = str({"class_path": class_path, "init_args": init_args})
|
||||
elif argv_class.startswith("{"):
|
||||
# the user passed a config as a dict
|
||||
config = argv_class
|
||||
else:
|
||||
# the user passed the shorthand format
|
||||
init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period
|
||||
for cls in classes:
|
||||
if cls.__name__ == argv_class:
|
||||
config = str(_global_add_class_path(cls, init_args))
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Could not generate a config for {repr(argv_class)}")
|
||||
return clean_argv + [argv_key, config]
|
||||
|
||||
|
||||
class SaveConfigCallback(Callback):
|
||||
"""Saves a LightningCLI config to the log_dir when training starts.
|
||||
|
@ -328,6 +441,11 @@ class LightningCLI:
|
|||
self.add_default_arguments_to_parser(parser)
|
||||
self.add_core_arguments_to_parser(parser)
|
||||
self.add_arguments_to_parser(parser)
|
||||
# add default optimizer args if necessary
|
||||
if not parser._optimizers:  # skip if the user already added optimizer args in `add_arguments_to_parser`
|
||||
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes)
|
||||
if not parser._lr_schedulers:  # skip if the user already added lr scheduler args in `add_arguments_to_parser`
|
||||
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
|
||||
self.link_optimizers_and_lr_schedulers(parser)
|
||||
|
||||
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
|
||||
|
|
|
@ -7,6 +7,6 @@ torchtext>=0.7
|
|||
onnx>=1.7.0
|
||||
onnxruntime>=1.3.0
|
||||
hydra-core>=1.0
|
||||
jsonargparse[signatures]>=3.19.0
|
||||
jsonargparse[signatures]>=3.19.3
|
||||
gcsfs>=2021.5.0
|
||||
rich>=10.2.2
|
||||
|
|
|
@ -33,7 +33,15 @@ from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
|||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback
|
||||
from pytorch_lightning.utilities.cli import (
|
||||
instantiate_class,
|
||||
LightningArgumentParser,
|
||||
LightningCLI,
|
||||
LR_SCHEDULER_REGISTRY,
|
||||
OPTIMIZER_REGISTRY,
|
||||
SaveConfigCallback,
|
||||
)
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
@ -678,12 +686,20 @@ def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
|
|||
assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50
|
||||
|
||||
|
||||
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
|
||||
@pytest.mark.parametrize("use_registries", [False, True])
|
||||
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir):
|
||||
class MyLightningCLI(LightningCLI):
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1")
|
||||
parser.add_optimizer_args(
|
||||
OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam,
|
||||
nested_key="optim1",
|
||||
link_to="model.optim1",
|
||||
)
|
||||
parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
|
||||
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler")
|
||||
parser.add_lr_scheduler_args(
|
||||
LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR,
|
||||
link_to="model.scheduler",
|
||||
)
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
|
||||
|
@ -692,20 +708,26 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
|
|||
self.optim2 = instantiate_class(self.parameters(), optim2)
|
||||
self.scheduler = instantiate_class(self.optim1, scheduler)
|
||||
|
||||
cli_args = [
|
||||
"fit",
|
||||
f"--trainer.default_root_dir={tmpdir}",
|
||||
"--trainer.max_epochs=1",
|
||||
"--optim2.class_path=torch.optim.SGD",
|
||||
"--optim2.init_args.lr=0.01",
|
||||
"--lr_scheduler.gamma=0.2",
|
||||
]
|
||||
cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1", "--lr_scheduler.gamma=0.2"]
|
||||
if use_registries:
|
||||
cli_args += [
|
||||
"--optim1",
|
||||
"Adam",
|
||||
"--optim1.weight_decay",
|
||||
"0.001",
|
||||
"--optim2=SGD",
|
||||
"--optim2.lr=0.01",
|
||||
"--lr_scheduler=ExponentialLR",
|
||||
]
|
||||
else:
|
||||
cli_args += ["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"]
|
||||
|
||||
with mock.patch("sys.argv", ["any.py"] + cli_args):
|
||||
cli = MyLightningCLI(TestModel)
|
||||
|
||||
assert isinstance(cli.model.optim1, torch.optim.Adam)
|
||||
assert isinstance(cli.model.optim2, torch.optim.SGD)
|
||||
assert cli.model.optim2.param_groups[0]["lr"] == 0.01
|
||||
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
|
||||
|
||||
|
||||
|
@ -829,6 +851,190 @@ def test_lightning_cli_run():
|
|||
assert isinstance(cli.model, LightningModule)
|
||||
|
||||
|
||||
@OPTIMIZER_REGISTRY
|
||||
class CustomAdam(torch.optim.Adam):
|
||||
pass
|
||||
|
||||
|
||||
@LR_SCHEDULER_REGISTRY
|
||||
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
|
||||
pass
|
||||
|
||||
|
||||
def test_registries(tmpdir):
|
||||
assert "SGD" in OPTIMIZER_REGISTRY.names
|
||||
assert "RMSprop" in OPTIMIZER_REGISTRY.names
|
||||
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
|
||||
|
||||
assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
|
||||
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
|
||||
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
|
||||
|
||||
with pytest.raises(MisconfigurationException, match="is already present in the registry"):
|
||||
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
|
||||
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)
|
||||
|
||||
|
||||
def test_registries_resolution():
|
||||
"""This test validates that the registries are used when the simplified command-line notation is used."""
|
||||
cli_args = [
|
||||
"--optimizer",
|
||||
"Adam",
|
||||
"--optimizer.lr",
|
||||
"0.0001",
|
||||
"--lr_scheduler",
|
||||
"StepLR",
|
||||
"--lr_scheduler.step_size=50",
|
||||
]
|
||||
|
||||
with mock.patch("sys.argv", ["any.py"] + cli_args):
|
||||
cli = LightningCLI(BoringModel, run=False)
|
||||
|
||||
optimizers, lr_scheduler = cli.model.configure_optimizers()
|
||||
assert isinstance(optimizers[0], torch.optim.Adam)
|
||||
assert optimizers[0].param_groups[0]["lr"] == 0.0001
|
||||
assert lr_scheduler[0].step_size == 50
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["args", "expected", "nested_key", "registry"],
|
||||
[
|
||||
(
|
||||
["--optimizer", "Adadelta"],
|
||||
{"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}},
|
||||
"optimizer",
|
||||
OPTIMIZER_REGISTRY,
|
||||
),
|
||||
(
|
||||
["--optimizer", "Adadelta", "--optimizer.lr", "10"],
|
||||
{"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}},
|
||||
"optimizer",
|
||||
OPTIMIZER_REGISTRY,
|
||||
),
|
||||
(
|
||||
["--lr_scheduler", "OneCycleLR"],
|
||||
{"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}},
|
||||
"lr_scheduler",
|
||||
LR_SCHEDULER_REGISTRY,
|
||||
),
|
||||
(
|
||||
["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"],
|
||||
{"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}},
|
||||
"lr_scheduler",
|
||||
LR_SCHEDULER_REGISTRY,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
|
||||
base = ["any.py", "--trainer.max_epochs=1"]
|
||||
argv = base + args
|
||||
new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
|
||||
assert new_argv == base + [f"--{nested_key}", str(expected)]
|
||||
|
||||
|
||||
def test_optimizers_and_lr_schedulers_reload(tmpdir):
|
||||
base = ["any.py", "--trainer.max_epochs=1"]
|
||||
input = base + [
|
||||
"--lr_scheduler",
|
||||
"OneCycleLR",
|
||||
"--lr_scheduler.total_steps=10",
|
||||
"--lr_scheduler.max_lr=1",
|
||||
"--optimizer",
|
||||
"Adam",
|
||||
"--optimizer.lr=0.1",
|
||||
]
|
||||
|
||||
# save config
|
||||
out = StringIO()
|
||||
with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
|
||||
LightningCLI(BoringModel, run=False)
|
||||
|
||||
# validate yaml
|
||||
yaml_config = out.getvalue()
|
||||
dict_config = yaml.safe_load(yaml_config)
|
||||
assert dict_config["optimizer"]["class_path"] == "torch.optim.adam.Adam"
|
||||
assert dict_config["optimizer"]["init_args"]["lr"] == 0.1
|
||||
assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
|
||||
|
||||
# reload config
|
||||
yaml_config_file = tmpdir / "config.yaml"
|
||||
yaml_config_file.write_text(yaml_config, "utf-8")
|
||||
with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
|
||||
LightningCLI(BoringModel, run=False)
|
||||
|
||||
|
||||
def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir):
|
||||
class TestLightningCLI(LightningCLI):
|
||||
def __init__(self, *args):
|
||||
super().__init__(*args, run=False)
|
||||
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
|
||||
parser.add_optimizer_args(
|
||||
(torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
|
||||
)
|
||||
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
|
||||
parser.add_argument("--something", type=str, nargs="+")
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
|
||||
super().__init__()
|
||||
self.opt1_config = opt1_config
|
||||
self.opt2_config = opt2_config
|
||||
self.sch_config = sch_config
|
||||
opt1 = instantiate_class(self.parameters(), opt1_config)
|
||||
assert isinstance(opt1, torch.optim.Adam)
|
||||
opt2 = instantiate_class(self.parameters(), opt2_config)
|
||||
assert isinstance(opt2, torch.optim.ASGD)
|
||||
sch = instantiate_class(opt1, sch_config)
|
||||
assert isinstance(sch, torch.optim.lr_scheduler.OneCycleLR)
|
||||
|
||||
base = ["any.py", "--trainer.max_epochs=1"]
|
||||
input = base + [
|
||||
"--lr_scheduler",
|
||||
"OneCycleLR",
|
||||
"--lr_scheduler.total_steps=10",
|
||||
"--lr_scheduler.max_lr=1",
|
||||
"--opt1",
|
||||
"Adam",
|
||||
"--opt2.lr=0.1",
|
||||
"--opt2",
|
||||
"ASGD",
|
||||
"--lr_scheduler.anneal_strategy=linear",
|
||||
"--something",
|
||||
"a",
|
||||
"b",
|
||||
"c",
|
||||
]
|
||||
|
||||
# save config
|
||||
out = StringIO()
|
||||
with mock.patch("sys.argv", input + ["--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
|
||||
TestLightningCLI(TestModel)
|
||||
|
||||
# validate yaml
|
||||
yaml_config = out.getvalue()
|
||||
dict_config = yaml.safe_load(yaml_config)
|
||||
assert dict_config["opt1"]["class_path"] == "torch.optim.adam.Adam"
|
||||
assert dict_config["opt2"]["class_path"] == "torch.optim.asgd.ASGD"
|
||||
assert dict_config["opt2"]["init_args"]["lr"] == 0.1
|
||||
assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
|
||||
assert dict_config["lr_scheduler"]["init_args"]["anneal_strategy"] == "linear"
|
||||
assert dict_config["something"] == ["a", "b", "c"]
|
||||
|
||||
# reload config
|
||||
yaml_config_file = tmpdir / "config.yaml"
|
||||
yaml_config_file.write_text(yaml_config, "utf-8")
|
||||
with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
|
||||
cli = TestLightningCLI(TestModel)
|
||||
|
||||
assert cli.model.opt1_config["class_path"] == "torch.optim.adam.Adam"
|
||||
assert cli.model.opt2_config["class_path"] == "torch.optim.asgd.ASGD"
|
||||
assert cli.model.opt2_config["init_args"]["lr"] == 0.1
|
||||
assert cli.model.sch_config["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
|
||||
assert cli.model.sch_config["init_args"]["anneal_strategy"] == "linear"
|
||||
|
||||
|
||||
@RunIf(min_python="3.7.3") # bpo-17185: `autospec=True` and `inspect.signature` do not play well
|
||||
def test_lightning_cli_config_with_subcommand():
|
||||
config = {"test": {"trainer": {"limit_test_batches": 1}, "verbose": True, "ckpt_path": "foobar"}}
|
||||
|
|