diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index 38fa0662a0..629d785662 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -217,8 +217,8 @@ If the CLI is implemented as ``LightningCLI(MyMainModel)`` the configuration wou It is also possible to combine ``subclass_mode_model=True`` and submodules, thereby having two levels of ``class_path``. -Optimizers -^^^^^^^^^^ +Fixed optimizer and scheduler +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ In some cases, fixing the optimizer and/or learning scheduler might be desired instead of allowing multiple. For this, you can manually add the arguments for specific classes by subclassing the CLI. The following code snippet shows how to @@ -251,58 +251,88 @@ where the arguments can be passed directly through the command line without spec $ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2 -The automatic implementation of ``configure_optimizers`` can be disabled by linking the configuration group. An example -can be when someone wants to add support for multiple optimizers: + +Multiple optimizers and schedulers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, the CLIs support multiple optimizers and/or learning schedulers, automatically implementing +``configure_optimizers``. This behavior can be disabled by providing ``auto_configure_optimizers=False`` on +instantiation of :class:`~pytorch_lightning.cli.LightningCLI`. This would be required for example to support multiple +optimizers, for each selecting a particular optimizer class. Similar to multiple submodules, this can be done via +`dependency injection `__. Unlike the submodules, it is not possible +to expect an instance of a class, because optimizers require the module's parameters to optimize, which are only +available after instantiation of the module. Learning schedulers are a similar situation, requiring an optimizer +instance. For these cases, dependency injection involves providing a function that instantiates the respective class +when called. + +An example of a model that uses two optimizers is the following: .. code-block:: python - from pytorch_lightning.cli import instantiate_class + from typing import Iterable + from torch.optim import Optimizer + + + OptimizerCallable = Callable[[Iterable], Optimizer] class MyModel(LightningModule): - def __init__(self, optimizer1_init: dict, optimizer2_init: dict): + def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable): super().__init__() - self.optimizer1_init = optimizer1_init - self.optimizer2_init = optimizer2_init + self.optimizer1 = optimizer1 + self.optimizer2 = optimizer2 def configure_optimizers(self): - optimizer1 = instantiate_class(self.parameters(), self.optimizer1_init) - optimizer2 = instantiate_class(self.parameters(), self.optimizer2_init) + optimizer1 = self.optimizer1(self.parameters()) + optimizer2 = self.optimizer2(self.parameters()) return [optimizer1, optimizer2] - class MyLightningCLI(LightningCLI): - def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init") - parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init") + cli = MyLightningCLI(MyModel, auto_configure_optimizers=False) - - cli = MyLightningCLI(MyModel) - -The value given to ``optimizer*_init`` will always be a dictionary including ``class_path`` and ``init_args`` entries. -The function :func:`~pytorch_lightning.cli.instantiate_class` takes care of importing the class defined in -``class_path`` and instantiating it using some positional arguments, in this case ``self.parameters()``, and the -``init_args``. Any number of optimizers and learning rate schedulers can be added when using ``link_to``. - -With shorthand notation: +Note the type ``Callable[[Iterable], Optimizer]``, which denotes a function that receives a singe argument, some +learnable parameters, and returns an optimizer instance. With this, from the command line it is possible to select the +class and init arguments for each of the optimizers, as follows: .. code-block:: bash $ python trainer.py fit \ - --optimizer1=Adam \ - --optimizer1.lr=0.01 \ - --optimizer2=AdamW \ - --optimizer2.lr=0.0001 + --model.optimizer1=Adam \ + --model.optimizer1.lr=0.01 \ + --model.optimizer2=AdamW \ + --model.optimizer2.lr=0.0001 -You can also pass the class path directly, for example, if the optimizer hasn't been imported: +In the example above, the ``OptimizerCallable`` type alias was created to illustrate what the type hint means. For +convenience, this type alias and one for learning schedulers is available in the ``cli`` module. An example of a model +that uses dependency injection for an optimizer and a learning scheduler is: -.. code-block:: bash +.. code-block:: python - $ python trainer.py fit \ - --optimizer1=torch.optim.Adam \ - --optimizer1.lr=0.01 \ - --optimizer2=torch.optim.AdamW \ - --optimizer2.lr=0.0001 + from pytorch_lightning.cli import OptimizerCallable, LRSchedulerCallable, LightningCLI + + + class MyModel(LightningModule): + def __init__( + self, + optimizer: OptimizerCallable = torch.optim.Adam, + scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, + ): + super().__init__() + self.optimizer = optimizer + self.scheduler = scheduler + + def configure_optimizers(self): + optimizer = self.optimizer(self.parameters()) + scheduler = self.scheduler(self.parameters()) + return {"optimizer": optimizer, "lr_scheduler": scheduler} + + + cli = MyLightningCLI(MyModel, auto_configure_optimizers=False) + +Note that for this example, classes are used as defaults. This is compatible with the type hints, since they are also +callables that receive the same first argument and return an instance of the class. Classes that have more than one +required argument will not work as default. For these cases a lambda function can be used, e.g. ``optimizer: +OptimizerCallable = lambda p: torch.optim.SGD(p, lr=0.01)``. Run from Python diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index eeea9a7f40..bf4b6050d8 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -5,5 +5,5 @@ matplotlib>3.1, <3.6.2 omegaconf>=2.0.5, <2.3.0 hydra-core>=1.0.5, <1.3.0 -jsonargparse[signatures]>=4.17.0, <4.18.0 +jsonargparse[signatures]>=4.18.0, <4.19.0 rich>=10.14.0, !=10.15.0.a, <13.0.0 diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 814e1ddf5d..60afab1d0b 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -30,9 +30,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a warning when `self.log(..., logger=True)` is called without a configured logger ([#15814](https://github.com/Lightning-AI/lightning/pull/15814)) -- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826)) +- Added `LightningCLI` support for optimizer and learning schedulers via callable type dependency injection ([#15869](https://github.com/Lightning-AI/lightning/pull/15869)) +- Added support for activation checkpointing for the `DDPFullyShardedNativeStrategy` strategy ([#15826](https://github.com/Lightning-AI/lightning/pull/15826)) + - Added the option to set `DDPFullyShardedNativeStrategy(cpu_offload=True|False)` via bool instead of needing to pass a configufation object ([#15832](https://github.com/Lightning-AI/lightning/pull/15832)) diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index 95822a522e..54871e6173 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -15,7 +15,7 @@ import os import sys from functools import partial, update_wrapper from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import torch from lightning_utilities.core.imports import RequirementCache @@ -24,6 +24,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl from lightning_lite.utilities.cloud_io import get_filesystem +from lightning_lite.utilities.types import _TORCH_LRSCHEDULER from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -49,9 +50,6 @@ else: locals()["Namespace"] = object -ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]] - - class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None: super().__init__(optimizer, *args, **kwargs) @@ -59,9 +57,15 @@ class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): # LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch: -LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau) -LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau] -LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]] +LRSchedulerTypeTuple = (_TORCH_LRSCHEDULER, ReduceLROnPlateau) +LRSchedulerTypeUnion = Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau] +LRSchedulerType = Union[Type[_TORCH_LRSCHEDULER], Type[ReduceLROnPlateau]] + + +# Type aliases intended for convenience of CLI developers +ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]] +OptimizerCallable = Callable[[Iterable], Optimizer] +LRSchedulerCallable = Callable[[Optimizer], Union[_TORCH_LRSCHEDULER, ReduceLROnPlateau]] class LightningArgumentParser(ArgumentParser): @@ -274,6 +278,7 @@ class LightningCLI: subclass_mode_data: bool = False, args: ArgsType = None, run: bool = True, + auto_configure_optimizers: bool = True, auto_registry: bool = False, **kwargs: Any, # Remove with deprecations of v1.10 ) -> None: @@ -326,6 +331,7 @@ class LightningCLI: self.trainer_defaults = trainer_defaults or {} self.seed_everything_default = seed_everything_default self.parser_kwargs = parser_kwargs or {} # type: ignore[var-annotated] # github.com/python/mypy/issues/6463 + self.auto_configure_optimizers = auto_configure_optimizers self._handle_deprecated_params(kwargs) @@ -447,10 +453,11 @@ class LightningCLI: self.add_core_arguments_to_parser(parser) self.add_arguments_to_parser(parser) # add default optimizer args if necessary - if not parser._optimizers: # already added by the user in `add_arguments_to_parser` - parser.add_optimizer_args((Optimizer,)) - if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser` - parser.add_lr_scheduler_args(LRSchedulerTypeTuple) + if self.auto_configure_optimizers: + if not parser._optimizers: # already added by the user in `add_arguments_to_parser` + parser.add_optimizer_args((Optimizer,)) + if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser` + parser.add_lr_scheduler_args(LRSchedulerTypeTuple) self.link_optimizers_and_lr_schedulers(parser) def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: @@ -602,6 +609,9 @@ class LightningCLI: def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None: """Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers` method if a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC'.""" + if not self.auto_configure_optimizers: + return + parser = self._parser(subcommand) def get_automatic( diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index f80d32f4fb..43d2d70fa7 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -36,7 +36,9 @@ from pytorch_lightning.cli import ( instantiate_class, LightningArgumentParser, LightningCLI, + LRSchedulerCallable, LRSchedulerTypeTuple, + OptimizerCallable, SaveConfigCallback, ) from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel @@ -706,6 +708,56 @@ def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR) +def test_lightning_cli_optimizers_and_lr_scheduler_with_callable_type(): + class TestModel(BoringModel): + def __init__( + self, + optim1: OptimizerCallable = torch.optim.Adam, + optim2: OptimizerCallable = torch.optim.Adagrad, + scheduler: LRSchedulerCallable = torch.optim.lr_scheduler.ConstantLR, + ): + super().__init__() + self.optim1 = optim1 + self.optim2 = optim2 + self.scheduler = scheduler + + def configure_optimizers(self): + optim1 = self.optim1(self.parameters()) + optim2 = self.optim2(self.parameters()) + scheduler = self.scheduler(optim2) + return ( + {"optimizer": optim1}, + {"optimizer": optim2, "lr_scheduler": scheduler}, + ) + + out = StringIO() + with mock.patch("sys.argv", ["any.py", "-h"]), redirect_stdout(out), pytest.raises(SystemExit): + LightningCLI(TestModel, run=False, auto_configure_optimizers=False) + out = out.getvalue() + assert "--optimizer" not in out + assert "--lr_scheduler" not in out + assert "--model.optim1" in out + assert "--model.optim2" in out + assert "--model.scheduler" in out + + cli_args = [ + "--model.optim1=Adagrad", + "--model.optim2=SGD", + "--model.optim2.lr=0.007", + "--model.scheduler=ExponentialLR", + "--model.scheduler.gamma=0.3", + ] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = LightningCLI(TestModel, run=False, auto_configure_optimizers=False) + + init = cli.model.configure_optimizers() + assert isinstance(init[0]["optimizer"], torch.optim.Adagrad) + assert isinstance(init[1]["optimizer"], torch.optim.SGD) + assert isinstance(init[1]["lr_scheduler"], torch.optim.lr_scheduler.ExponentialLR) + assert init[1]["optimizer"].param_groups[0]["lr"] == 0.007 + assert init[1]["lr_scheduler"].gamma == 0.3 + + @pytest.mark.parametrize("fn", [fn.value for fn in TrainerFn]) def test_lightning_cli_trainer_fn(fn): class TestCLI(LightningCLI):