LightningCLI support for optimizers and schedulers via dependency injection (#15869)

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Mauricio Villegas 2022-12-12 16:36:19 +01:00 committed by GitHub
parent 38acba08fc
commit ed52823c3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 141 additions and 47 deletions

View File

@ -217,8 +217,8 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou
It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``.
Optimizers
^^^^^^^^^^
Fixed optimizer and scheduler
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
In some cases, fixing the optimizer and/or learning scheduler might be desired instead of allowing multiple. For this,
you can manually add the arguments for specific classes by subclassing the CLI. The following code snippet shows how to
@ -251,58 +251,88 @@ where the arguments can be passed directly through the command line without spec
$ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
The automatic implementation of ``configure_optimizers`` can be disabled by linking the configuration group. An example
can be when someone wants to add support for multiple optimizers:
Multiple optimizers and schedulers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
By default, the CLIs support multiple optimizers and/or learning schedulers, automatically implementing
``configure_optimizers``. This behavior can be disabled by providing ``auto_configure_optimizers=False`` on
instantiation of :class:`~pytorch_lightning.cli.LightningCLI`. This would be required for example to support multiple
optimizers, for each selecting a particular optimizer class. Similar to multiple submodules, this can be done via
`dependency injection <https://en.wikipedia.org/wiki/Dependency_injection>`__. Unlike the submodules, it is not possible
to expect an instance of a class, because optimizers require the module's parameters to optimize, which are only
available after instantiation of the module. Learning schedulers are a similar situation, requiring an optimizer
instance. For these cases, dependency injection involves providing a function that instantiates the respective class
when called.
An example of a model that uses two optimizers is the following:
.. code-block:: python
from pytorch_lightning.cli import instantiate_class
from typing import Iterable
from torch.optim import Optimizer
OptimizerCallable = Callable[[Iterable], Optimizer]
class MyModel(LightningModule):
def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
super().__init__()
self.optimizer1_init = optimizer1_init
self.optimizer2_init = optimizer2_init
self.optimizer1 = optimizer1
self.optimizer2 = optimizer2
def configure_optimizers(self):
optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init)
optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init)
optimizer1 = self.optimizer1(self.parameters())
optimizer2 = self.optimizer2(self.parameters())
return [optimizer1, optimizer2]
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
cli = MyLightningCLI(MyModel)
The value given to ``optimizer*_init`` will always be a dictionary including ``class_path`` and ``init_args`` entries.
The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in
``class_path`` and instantiating it using some positional arguments, in this case ``self.parameters()``, and the
``init_args``. Any number of optimizers and learning rate schedulers can be added when using ``link_to``.
With shorthand notation:
Note the type ``Callable[[Iterable], Optimizer]``, which denotes a function that receives a singe argument, some
learnable parameters, and returns an optimizer instance. With this, from the command line it is possible to select the
class and init arguments for each of the optimizers, as follows:
.. code-block:: bash
$ python trainer.py fit \
--optimizer1=Adam \
--optimizer1.lr=0.01 \
--optimizer2=AdamW \
--optimizer2.lr=0.0001
--model.optimizer1=Adam \
--model.optimizer1.lr=0.01 \
--model.optimizer2=AdamW \
--model.optimizer2.lr=0.0001
You can also pass the class path directly, for example, if the optimizer hasn't been imported:
In the example above, the ``OptimizerCallable`` type alias was created to illustrate what the type hint means. For
convenience, this type alias and one for learning schedulers is available in the ``cli`` module. An example of a model
that uses dependency injection for an optimizer and a learning scheduler is:
.. code-block:: bash
.. code-block:: python
$ python trainer.py fit \
--optimizer1=torch.optim.Adam \
--optimizer1.lr=0.01 \
--optimizer2=torch.optim.AdamW \
--optimizer2.lr=0.0001
from pytorch_lightning.cli import OptimizerCallable, LRSchedulerCallable, LightningCLI
class MyModel(LightningModule):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()
self.optimizer = optimizer
self.scheduler = scheduler
def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
scheduler = self.scheduler(self.parameters())
return {"optimizer": optimizer, "lr_scheduler": scheduler}
cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
Note that for this example, classes are used as defaults. This is compatible with the type hints, since they are also
callables that receive the same first argument and return an instance of the class. Classes that have more than one
required argument will not work as default. For these cases a lambda function can be used, e.g. ``optimizer:
OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01)``.
Run from Python

View File

@ -5,5 +5,5 @@
matplotlib>3.1, <3.6.2
omegaconf>=2.0.5, <2.3.0
hydra-core>=1.0.5, <1.3.0
jsonargparse[signatures]>=4.17.0, <4.18.0
jsonargparse[signatures]>=4.18.0, <4.19.0
rich>=10.14.0, !=10.15.0.a, <13.0.0

View File

@ -30,9 +30,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814))
- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
- Added `LightningCLI` support for optimizer and learning schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869))
- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826))
- Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832))

View File

@ -15,7 +15,7 @@ import os
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
import torch
from lightning_utilities.core.imports import RequirementCache
@ -24,6 +24,7 @@ from torch.optim import Optimizer
import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _TORCH_LRSCHEDULER
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@ -49,9 +50,6 @@ else:
locals()["Namespace"] = object
ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
super().__init__(optimizer, *args, **kwargs)
@ -59,9 +57,15 @@ class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch:
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]
LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]]
# Type aliases intended for convenience of CLI developers
ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]
OptimizerCallable = Callable[[Iterable], Optimizer]
LRSchedulerCallable = Callable[[Optimizer], Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]]
class LightningArgumentParser(ArgumentParser):
@ -274,6 +278,7 @@ class LightningCLI:
subclass_mode_data: bool = False,
args: ArgsType = None,
run: bool = True,
auto_configure_optimizers: bool = True,
auto_registry: bool = False,
**kwargs: Any, # Remove with deprecations of v1.10
) -> None:
@ -326,6 +331,7 @@ class LightningCLI:
self.trainer_defaults = trainer_defaults or {}
self.seed_everything_default = seed_everything_default
self.parser_kwargs = parser_kwargs or {} # type: ignore[var-annotated] # github.com/python/mypy/issues/6463
self.auto_configure_optimizers = auto_configure_optimizers
self._handle_deprecated_params(kwargs)
@ -447,10 +453,11 @@ class LightningCLI:
self.add_core_arguments_to_parser(parser)
self.add_arguments_to_parser(parser)
# add default optimizer args if necessary
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args((Optimizer,))
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
if self.auto_configure_optimizers:
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args((Optimizer,))
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
self.link_optimizers_and_lr_schedulers(parser)
def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
@ -602,6 +609,9 @@ class LightningCLI:
def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
"""Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method
if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'."""
if not self.auto_configure_optimizers:
return
parser = self._parser(subcommand)
def get_automatic(

View File

@ -36,7 +36,9 @@ from pytorch_lightning.cli import (
instantiate_class,
LightningArgumentParser,
LightningCLI,
LRSchedulerCallable,
LRSchedulerTypeTuple,
OptimizerCallable,
SaveConfigCallback,
)
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
@ -706,6 +708,56 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type():
class TestModel(BoringModel):
def __init__(
self,
optim1: OptimizerCallable = torch.optim.Adam,
optim2: OptimizerCallable = torch.optim.Adagrad,
scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR,
):
super().__init__()
self.optim1 = optim1
self.optim2 = optim2
self.scheduler = scheduler
def configure_optimizers(self):
optim1 = self.optim1(self.parameters())
optim2 = self.optim2(self.parameters())
scheduler = self.scheduler(optim2)
return (
{"optimizer": optim1},
{"optimizer": optim2, "lr_scheduler": scheduler},
)
out = StringIO()
with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit):
LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
out = out.getvalue()
assert "--optimizer" not in out
assert "--lr_scheduler" not in out
assert "--model.optim1" in out
assert "--model.optim2" in out
assert "--model.scheduler" in out
cli_args = [
"--model.optim1=Adagrad",
"--model.optim2=SGD",
"--model.optim2.lr=0.007",
"--model.scheduler=ExponentialLR",
"--model.scheduler.gamma=0.3",
]
with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False)
init = cli.model.configure_optimizers()
assert isinstance(init[0]["optimizer"], torch.optim.Adagrad)
assert isinstance(init[1]["optimizer"], torch.optim.SGD)
assert isinstance(init[1]["lr_scheduler"], torch.optim.lr_scheduler.ExponentialLR)
assert init[1]["optimizer"].param_groups[0]["lr"] == 0.007
assert init[1]["lr_scheduler"].gamma == 0.3
@pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn])
def test_lightning_cli_trainer_fn(fn):
class TestCLI(LightningCLI):