LightningCLI support for optimizers and schedulers via dependency injection (#15869)
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent 38acba08fc
commit ed52823c3f
@@ -217,8 +217,8 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou

 It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.


-Optimizers
-^^^^^^^^^^
+Fixed optimizer and scheduler
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 In some cases, fixing the optimizer and/or learning scheduler might be desired instead of allowing multiple. For this,
 you can manually add the arguments for specific classes by subclassing the CLI. The following code snippet shows how to

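The snippet that paragraph refers to falls outside the hunk shown above. For reference, a minimal sketch of such a subclassed CLI (the ``Adam``/``ExponentialLR`` choices are assumptions, chosen to be consistent with the ``--optimizer.lr``/``--lr_scheduler.gamma`` command in the next hunk) could look like:

.. code-block:: python

    import torch

    from pytorch_lightning.cli import LightningCLI


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # Fix the optimizer and scheduler classes; only their init args remain configurable
            parser.add_optimizer_args(torch.optim.Adam)
            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)


    cli = MyLightningCLI(MyModel)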
@@ -251,58 +251,88 @@ where the arguments can be passed directly through the command line without spec

     $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2

-The automatic implementation of ``configure_optimizers`` can be disabled by linking the configuration group. An example
-can be when someone wants to add support for multiple optimizers:
+Multiple optimizers and schedulers
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+By default, the CLIs support multiple optimizers and/or learning schedulers, automatically implementing
+``configure_optimizers``. This behavior can be disabled by providing ``auto_configure_optimizers=False`` on
+instantiation of :class:`~pytorch_lightning.cli.LightningCLI`. This would be required, for example, to support multiple
+optimizers, selecting a particular optimizer class for each. Similar to multiple submodules, this can be done via
+`dependency injection <https://en.wikipedia.org/wiki/Dependency_injection>`__. Unlike the submodules, it is not possible
+to expect an instance of a class, because optimizers require the module's parameters to optimize, which are only
+available after instantiation of the module. Learning schedulers are in a similar situation, requiring an optimizer
+instance. For these cases, dependency injection involves providing a function that instantiates the respective class
+when called.

 An example of a model that uses two optimizers is the following:

 .. code-block:: python

-    from pytorch_lightning.cli import instantiate_class
+    from typing import Callable, Iterable
+    from torch.optim import Optimizer
+
+
+    OptimizerCallable = Callable[[Iterable], Optimizer]


     class MyModel(LightningModule):
-        def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
+        def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
             super().__init__()
-            self.optimizer1_init = optimizer1_init
-            self.optimizer2_init = optimizer2_init
+            self.optimizer1 = optimizer1
+            self.optimizer2 = optimizer2

         def configure_optimizers(self):
-            optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init)
-            optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init)
+            optimizer1 = self.optimizer1(self.parameters())
+            optimizer2 = self.optimizer2(self.parameters())
             return [optimizer1, optimizer2]


-    class MyLightningCLI(LightningCLI):
-        def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
-            parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
-
-
-    cli = MyLightningCLI(MyModel)
+    cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)

-The value given to ``optimizer*_init`` will always be a dictionary including ``class_path`` and ``init_args`` entries.
-The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in
-``class_path`` and instantiating it using some positional arguments, in this case ``self.parameters()``, and the
-``init_args``. Any number of optimizers and learning rate schedulers can be added when using ``link_to``.
-
-With shorthand notation:
+Note the type ``Callable[[Iterable], Optimizer]``, which denotes a function that receives a single argument, some
+learnable parameters, and returns an optimizer instance. With this, from the command line it is possible to select the
+class and init arguments for each of the optimizers, as follows:

 .. code-block:: bash

     $ python trainer.py fit \
-        --optimizer1=Adam \
-        --optimizer1.lr=0.01 \
-        --optimizer2=AdamW \
-        --optimizer2.lr=0.0001
+        --model.optimizer1=Adam \
+        --model.optimizer1.lr=0.01 \
+        --model.optimizer2=AdamW \
+        --model.optimizer2.lr=0.0001

-You can also pass the class path directly, for example, if the optimizer hasn't been imported:
+In the example above, the ``OptimizerCallable`` type alias was created to illustrate what the type hint means. For
+convenience, this type alias and one for learning schedulers are available in the ``cli`` module. An example of a model
+that uses dependency injection for an optimizer and a learning scheduler is:

-.. code-block:: bash
+.. code-block:: python

-    $ python trainer.py fit \
-        --optimizer1=torch.optim.Adam \
-        --optimizer1.lr=0.01 \
-        --optimizer2=torch.optim.AdamW \
-        --optimizer2.lr=0.0001
+    from pytorch_lightning.cli import OptimizerCallable, LRSchedulerCallable, LightningCLI
+
+
+    class MyModel(LightningModule):
+        def __init__(
+            self,
+            optimizer: OptimizerCallable = torch.optim.Adam,
+            scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+        ):
+            super().__init__()
+            self.optimizer = optimizer
+            self.scheduler = scheduler
+
+        def configure_optimizers(self):
+            optimizer = self.optimizer(self.parameters())
+            scheduler = self.scheduler(optimizer)
+            return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+
+    cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
+
+Note that for this example, classes are used as defaults. This is compatible with the type hints, since they are also
+callables that receive the same first argument and return an instance of the class. Classes that have more than one
+required argument will not work as a default. For these cases a lambda function can be used, e.g. ``optimizer:
+OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01)``.


 Run from Python

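As an editor's aside (not part of this diff): since the CLI is built on jsonargparse, the same optimizer selection can typically also be provided through a config file using the ``class_path``/``init_args`` structure. A hedged sketch for the two-optimizer model above, assuming a ``config.yaml`` passed as ``--config config.yaml``:

.. code-block:: yaml

    model:
      optimizer1:
        class_path: torch.optim.Adam
        init_args:
          lr: 0.01
      optimizer2:
        class_path: torch.optim.AdamW
        init_args:
          lr: 0.0001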
@@ -5,5 +5,5 @@
 matplotlib>3.1, <3.6.2
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
-jsonargparse[signatures]>=4.17.0, <4.18.0
+jsonargparse[signatures]>=4.18.0, <4.19.0
 rich>=10.14.0, !=10.15.0.a, <13.0.0

@@ -30,9 +30,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))

-
-- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
+- Added `LightningCLI` support for optimizer and learning schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869))
+
+
+- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))

 - Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configuration object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))

@@ -15,7 +15,7 @@ import os
 import sys
 from functools import partial, update_wrapper
 from types import MethodType
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import torch
 from lightning_utilities.core.imports import RequirementCache

@@ -24,6 +24,7 @@ from torch.optim import Optimizer

 import pytorch_lightning as pl
 from lightning_lite.utilities.cloud_io import get_filesystem
+from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -49,9 +50,6 @@ else:
     locals()["Namespace"] = object


-ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
-
-
 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
     def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
         super().__init__(optimizer, *args, **kwargs)

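For context on the ``ReduceLROnPlateau`` wrapper shown above (unchanged by this hunk): it adds the ``monitor`` parameter so the CLI can build Lightning's scheduler config. A hedged sketch of how it is typically selected when the default ``--optimizer``/``--lr_scheduler`` groups are enabled (the metric name ``val_loss`` is an assumption):

.. code-block:: bash

    $ python trainer.py fit \
        --optimizer=Adam \
        --lr_scheduler=ReduceLROnPlateau \
        --lr_scheduler.monitor=val_loss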
@@ -59,9 +57,15 @@ class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):


 # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
-LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
-LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
+LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]
+
+
+# Type aliases intended for convenience of CLI developers
+ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
+OptimizerCallable = Callable[[Iterable], Optimizer]
+LRSchedulerCallable = Callable[[Optimizer], Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]]


 class LightningArgumentParser(ArgumentParser):

@@ -274,6 +278,7 @@ class LightningCLI:
         subclass_mode_data: bool = False,
         args: ArgsType = None,
         run: bool = True,
+        auto_configure_optimizers: bool = True,
         auto_registry: bool = False,
         **kwargs: Any,  # Remove with deprecations of v1.10
     ) -> None:

@@ -326,6 +331,7 @@ class LightningCLI:
         self.trainer_defaults = trainer_defaults or {}
         self.seed_everything_default = seed_everything_default
         self.parser_kwargs = parser_kwargs or {}  # type: ignore[var-annotated]  # github.com/python/mypy/issues/6463
+        self.auto_configure_optimizers = auto_configure_optimizers

         self._handle_deprecated_params(kwargs)

@@ -447,10 +453,11 @@ class LightningCLI:
         self.add_core_arguments_to_parser(parser)
         self.add_arguments_to_parser(parser)
         # add default optimizer args if necessary
-        if not parser._optimizers:  # already added by the user in `add_arguments_to_parser`
-            parser.add_optimizer_args((Optimizer,))
-        if not parser._lr_schedulers:  # already added by the user in `add_arguments_to_parser`
-            parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
+        if self.auto_configure_optimizers:
+            if not parser._optimizers:  # already added by the user in `add_arguments_to_parser`
+                parser.add_optimizer_args((Optimizer,))
+            if not parser._lr_schedulers:  # already added by the user in `add_arguments_to_parser`
+                parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
         self.link_optimizers_and_lr_schedulers(parser)

     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:

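When ``auto_configure_optimizers`` is left at its default of ``True``, the ``parser.add_optimizer_args``/``parser.add_lr_scheduler_args`` calls above are what expose the generic ``--optimizer`` and ``--lr_scheduler`` groups. A hedged usage sketch (the class and value choices are illustrative):

.. code-block:: bash

    $ python trainer.py fit \
        --optimizer=AdamW \
        --optimizer.lr=0.001 \
        --lr_scheduler=ExponentialLR \
        --lr_scheduler.gamma=0.9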
@@ -602,6 +609,9 @@ class LightningCLI:
     def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
         """Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
         if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
+        if not self.auto_configure_optimizers:
+            return
+
         parser = self._parser(subcommand)

         def get_automatic(

@@ -36,7 +36,9 @@ from pytorch_lightning.cli import (
     instantiate_class,
     LightningArgumentParser,
     LightningCLI,
+    LRSchedulerCallable,
     LRSchedulerTypeTuple,
+    OptimizerCallable,
     SaveConfigCallback,
 )
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel

@@ -706,6 +708,56 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base
     assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


+def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
+    class TestModel(BoringModel):
+        def __init__(
+            self,
+            optim1: OptimizerCallable = torch.optim.Adam,
+            optim2: OptimizerCallable = torch.optim.Adagrad,
+            scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
+        ):
+            super().__init__()
+            self.optim1 = optim1
+            self.optim2 = optim2
+            self.scheduler = scheduler
+
+        def configure_optimizers(self):
+            optim1 = self.optim1(self.parameters())
+            optim2 = self.optim2(self.parameters())
+            scheduler = self.scheduler(optim2)
+            return (
+                {"optimizer": optim1},
+                {"optimizer": optim2, "lr_scheduler": scheduler},
+            )
+
+    out = StringIO()
+    with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
+        LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
+    out = out.getvalue()
+    assert "--optimizer" not in out
+    assert "--lr_scheduler" not in out
+    assert "--model.optim1" in out
+    assert "--model.optim2" in out
+    assert "--model.scheduler" in out
+
+    cli_args = [
+        "--model.optim1=Adagrad",
+        "--model.optim2=SGD",
+        "--model.optim2.lr=0.007",
+        "--model.scheduler=ExponentialLR",
+        "--model.scheduler.gamma=0.3",
+    ]
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
+
+    init = cli.model.configure_optimizers()
+    assert isinstance(init[0]["optimizer"], torch.optim.Adagrad)
+    assert isinstance(init[1]["optimizer"], torch.optim.SGD)
+    assert isinstance(init[1]["lr_scheduler"], torch.optim.lr_scheduler.ExponentialLR)
+    assert init[1]["optimizer"].param_groups[0]["lr"] == 0.007
+    assert init[1]["lr_scheduler"].gamma == 0.3
+
+
 @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
 def test_lightning_cli_trainer_fn(fn):
     class TestCLI(LightningCLI):