Deprecate CLI registries and update documentation (#13221)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
ad87d2cad0
commit
0ae9627bf8
|
@ -130,6 +130,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Deprecated `Trainer.reset_train_val_dataloaders()` in favor of `Trainer.reset_{train,val}_dataloader` ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
|
||||
|
||||
|
||||
- Deprecated LightningCLI's registries in favor of importing the respective package ([#13221](https://github.com/PyTorchLightning/pytorch-lightning/pull/13221))
|
||||
|
||||
|
||||
### Removed
|
||||
|
||||
- Removed the deprecated `Logger.close` method ([#13149](https://github.com/PyTorchLightning/pytorch-lightning/pull/13149))
|
||||
|
|
|
@ -27,7 +27,7 @@ Basic use
|
|||
|
||||
.. displayitem::
|
||||
:header: 2: Mix models and datasets
|
||||
:description: Register models, datasets, optimizers and learning rate schedulers
|
||||
:description: Support multiple models, datasets, optimizers and learning rate schedulers
|
||||
:col_css: col-md-4
|
||||
:button_link: lightning_cli_intermediate_2.html
|
||||
:height: 150
|
||||
|
@ -66,8 +66,8 @@ Advanced use
|
|||
:tag: advanced
|
||||
|
||||
.. displayitem::
|
||||
:header: Customize configs for complex projects
|
||||
:description: Learn how to connect complex projects with each Registry.
|
||||
:header: Customize for complex projects
|
||||
:description: Learn how to implement CLIs for complex projects.
|
||||
:col_css: col-md-6
|
||||
:button_link: lightning_cli_advanced_3.html
|
||||
:height: 150
|
||||
|
|
|
@ -63,29 +63,6 @@ This can be useful to implement custom logic without having to subclass the CLI,
|
|||
and argument parsing capabilities.
|
||||
|
||||
|
||||
Subclass registration
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
To use shorthand notation, the options need to be registered beforehand. This can be easily done with:
|
||||
|
||||
.. code-block::
|
||||
|
||||
LightningCLI(auto_registry=True) # False by default
|
||||
|
||||
which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`,
|
||||
:class:`~pytorch_lightning.core.module.LightningModule`,
|
||||
:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and
|
||||
:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules. This includes those in your own
|
||||
code.
|
||||
|
||||
Alternatively, if this is left unset, only the subclasses defined in PyTorch's :class:`torch.optim.Optimizer`,
|
||||
:class:`torch.optim.lr_scheduler._LRScheduler` and Lightning's :class:`~pytorch_lightning.callbacks.Callback` and
|
||||
:class:`~pytorch_lightning.loggers.LightningLoggerBase` subclassess will be registered.
|
||||
|
||||
In subsequent sections, we will go over adding specific classes to specific registries as well as how to use
|
||||
shorthand notation.
|
||||
|
||||
|
||||
Trainer Callbacks and arguments with class type
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
|
@ -107,14 +84,14 @@ file example that defines a couple of callbacks is the following:
|
|||
init_args:
|
||||
...
|
||||
|
||||
Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended
|
||||
Similar to the callbacks, any parameter in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended
|
||||
:class:`~pytorch_lightning.core.module.LightningModule` and
|
||||
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured
|
||||
the same way using :code:`class_path` and :code:`init_args`.
|
||||
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class, can be
|
||||
configured the same way using :code:`class_path` and :code:`init_args`. If the package that defines a subclass is
|
||||
imported before the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class is run, the name can be used instead of
|
||||
the full import path.
|
||||
|
||||
For callbacks in particular, Lightning simplifies the command line so that only
|
||||
the :class:`~pytorch_lightning.callbacks.Callback` name is required.
|
||||
The argument's order matters and the user needs to pass the arguments in the following way.
|
||||
From command line the syntax is the following:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
@ -127,7 +104,8 @@ The argument's order matters and the user needs to pass the arguments in the fol
|
|||
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
|
||||
...
|
||||
|
||||
Here is an example:
|
||||
Note the use of ``+`` to append a new callback to the list and that the ``init_args`` are applied to the previous
|
||||
callback appended. Here is an example:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
@ -137,43 +115,11 @@ Here is an example:
|
|||
--trainer.callbacks+=LearningRateMonitor \
|
||||
--trainer.callbacks.logging_interval=epoch
|
||||
|
||||
Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
|
||||
as described above:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
|
||||
|
||||
|
||||
@CALLBACK_REGISTRY
|
||||
class CustomCallback(Callback):
|
||||
...
|
||||
|
||||
|
||||
cli = LightningCLI(...)
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python ... --trainer.callbacks+=CustomCallback ...
|
||||
|
||||
.. note::
|
||||
|
||||
This shorthand notation is also supported inside a configuration file. The configuration file
|
||||
generated by calling the previous command with ``--print_config`` will have the full ``class_path`` notation.
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
trainer:
|
||||
callbacks:
|
||||
- class_path: your_class_path.CustomCallback
|
||||
init_args:
|
||||
...
|
||||
|
||||
|
||||
.. tip::
|
||||
|
||||
``--trainer.logger`` also supports shorthand notation and a ``LOGGER_REGISTRY`` is available to register custom
|
||||
Loggers.
|
||||
Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback`)
|
||||
always have the full ``class_path``'s, even when class name shorthand notation is used in command line or in input
|
||||
config files.
|
||||
|
||||
|
||||
Multiple models and/or datasets
|
||||
|
@ -377,12 +323,8 @@ example can be when one wants to add support for multiple optimizers:
|
|||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_optimizer_args(
|
||||
OPTIMIZER_REGISTRY.classes, nested_key="gen_optimizer", link_to="model.optimizer1_init"
|
||||
)
|
||||
parser.add_optimizer_args(
|
||||
OPTIMIZER_REGISTRY.classes, nested_key="gen_discriminator", link_to="model.optimizer2_init"
|
||||
)
|
||||
parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
|
||||
parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
|
||||
|
||||
|
||||
cli = MyLightningCLI(MyModel)
|
||||
|
@ -398,18 +340,17 @@ With shorthand notation:
|
|||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit \
|
||||
--gen_optimizer=Adam \
|
||||
--gen_optimizer.lr=0.01 \
|
||||
--gen_discriminator=AdamW \
|
||||
--gen_discriminator.lr=0.0001
|
||||
--optimizer1=Adam \
|
||||
--optimizer1.lr=0.01 \
|
||||
--optimizer2=AdamW \
|
||||
--optimizer2.lr=0.0001
|
||||
|
||||
You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
|
||||
``OPTIMIZER_REGISTRY``:
|
||||
You can also pass the class path directly, for example, if the optimizer hasn't been imported:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py fit \
|
||||
--gen_optimizer.class_path=torch.optim.Adam \
|
||||
--gen_optimizer.init_args.lr=0.01 \
|
||||
--gen_discriminator.class_path=torch.optim.AdamW \
|
||||
--gen_discriminator.init_args.lr=0.0001
|
||||
--optimizer1=torch.optim.Adam \
|
||||
--optimizer1.lr=0.01 \
|
||||
--optimizer2=torch.optim.AdamW \
|
||||
--optimizer2.lr=0.0001
|
||||
|
|
|
@ -46,9 +46,9 @@ This is what the Lightning CLI enables. Otherwise, this kind of configuration re
|
|||
----
|
||||
|
||||
*************************
|
||||
Register LightningModules
|
||||
Multiple LightningModules
|
||||
*************************
|
||||
Connect models across different files with the ``MODEL_REGISTRY`` to make them available from the CLI:
|
||||
To support multiple models, when instantiating ``LightningCLI`` omit the ``model_class`` parameter:
|
||||
|
||||
.. code:: python
|
||||
|
||||
|
@ -58,14 +58,12 @@ Connect models across different files with the ``MODEL_REGISTRY`` to make them a
|
|||
from pytorch_lightning.utilities import cli as pl_cli
|
||||
|
||||
|
||||
@pl_cli.MODEL_REGISTRY
|
||||
class Model1(DemoModel):
|
||||
def configure_optimizers(self):
|
||||
print("⚡", "using Model1", "⚡")
|
||||
return super().configure_optimizers()
|
||||
|
||||
|
||||
@pl_cli.MODEL_REGISTRY
|
||||
class Model2(DemoModel):
|
||||
def configure_optimizers(self):
|
||||
print("⚡", "using Model2", "⚡")
|
||||
|
@ -87,9 +85,9 @@ Now you can choose between any model from the CLI:
|
|||
----
|
||||
|
||||
********************
|
||||
Register DataModules
|
||||
Multiple DataModules
|
||||
********************
|
||||
Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to make them available from the CLI:
|
||||
To support multiple data modules, when instantiating ``LightningCLI`` omit the ``datamodule_class`` parameter:
|
||||
|
||||
.. code:: python
|
||||
|
||||
|
@ -99,14 +97,12 @@ Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to m
|
|||
from pytorch_lightning import demos
|
||||
|
||||
|
||||
@pl_cli.DATAMODULE_REGISTRY
|
||||
class FakeDataset1(BoringDataModule):
|
||||
def train_dataloader(self):
|
||||
print("⚡", "using FakeDataset1", "⚡")
|
||||
return torch.utils.data.DataLoader(self.random_train)
|
||||
|
||||
|
||||
@pl_cli.DATAMODULE_REGISTRY
|
||||
class FakeDataset2(BoringDataModule):
|
||||
def train_dataloader(self):
|
||||
print("⚡", "using FakeDataset2", "⚡")
|
||||
|
@ -127,10 +123,10 @@ Now you can choose between any dataset at runtime:
|
|||
|
||||
----
|
||||
|
||||
*******************
|
||||
Register optimizers
|
||||
*******************
|
||||
Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from the CLI:
|
||||
*****************
|
||||
Custom optimizers
|
||||
*****************
|
||||
Any subclass of ``torch.optim.Optimizer`` can be used as an optimizer:
|
||||
|
||||
.. code:: python
|
||||
|
||||
|
@ -140,14 +136,12 @@ Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from t
|
|||
from pytorch_lightning import demos
|
||||
|
||||
|
||||
@pl_cli.OPTIMIZER_REGISTRY
|
||||
class LitAdam(torch.optim.Adam):
|
||||
def step(self, closure):
|
||||
print("⚡", "using LitAdam", "⚡")
|
||||
super().step(closure)
|
||||
|
||||
|
||||
@pl_cli.OPTIMIZER_REGISTRY
|
||||
class FancyAdam(torch.optim.Adam):
|
||||
def step(self, closure):
|
||||
print("⚡", "using FancyAdam", "⚡")
|
||||
|
@ -166,7 +160,8 @@ Now you can choose between any optimizer at runtime:
|
|||
# use FancyAdam
|
||||
python main.py fit --optimizer FancyAdam
|
||||
|
||||
Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from ``torch.optim.optim``:
|
||||
Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from
|
||||
``torch.optim``:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
|
@ -180,10 +175,10 @@ If the optimizer you want needs other arguments, add them via the CLI (no need t
|
|||
|
||||
----
|
||||
|
||||
**********************
|
||||
Register LR schedulers
|
||||
**********************
|
||||
Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them available from the CLI:
|
||||
********************
|
||||
Custom LR schedulers
|
||||
********************
|
||||
Any subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
|
||||
|
||||
.. code:: python
|
||||
|
||||
|
@ -193,7 +188,6 @@ Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them
|
|||
from pytorch_lightning import demos
|
||||
|
||||
|
||||
@pl_cli.LR_SCHEDULER_REGISTRY
|
||||
class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
|
||||
def step(self):
|
||||
print("⚡", "using LitLRScheduler", "⚡")
|
||||
|
@ -210,7 +204,8 @@ Now you can choose between any learning rate scheduler at runtime:
|
|||
python main.py fit --lr_scheduler LitLRScheduler
|
||||
|
||||
|
||||
Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from ``torch.optim``:
|
||||
Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from
|
||||
``torch.optim``:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
|
@ -226,26 +221,31 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
|
|||
|
||||
----
|
||||
|
||||
*************************
|
||||
Register from any package
|
||||
*************************
|
||||
A shortcut to register many classes from a package is to use the ``register_classes`` method. Here we register all optimizers from the ``torch.optim`` library:
|
||||
************************
|
||||
Classes from any package
|
||||
************************
|
||||
In the previous sections the classes to select were defined in the same python file where the ``LightningCLI`` class is
|
||||
run. To select classes from any package by using only the class name, import the respective package:
|
||||
|
||||
.. code:: python
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.utilities import cli as pl_cli
|
||||
from pytorch_lightning import demos
|
||||
import my_code.models # noqa: F401
|
||||
import my_code.data_modules # noqa: F401
|
||||
import my_code.optimizers # noqa: F401
|
||||
|
||||
# add all PyTorch optimizers!
|
||||
pl_cli.OPTIMIZER_REGISTRY.register_classes(module=torch.optim, base_cls=torch.optim.Optimizer)
|
||||
cli = pl_cli.LightningCLI()
|
||||
|
||||
cli = pl_cli.LightningCLI(DemoModel, BoringDataModule)
|
||||
|
||||
Now use any of the optimizers in the ``torch.optim`` library:
|
||||
Now use any of the classes:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
python main.py fit --optimizer AdamW
|
||||
python main.py fit --model Model1 --data FakeDataset1 --optimizer LitAdam --lr_scheduler LitLRScheduler
|
||||
|
||||
This method is supported by all the registry classes.
|
||||
The ``# noqa: F401`` comment avoids a linter warning that the import is unused. It is also possible to select subclasses
|
||||
that have not been imported by giving the full import path:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
python main.py fit --model my_code.models.Model1
|
||||
|
|
|
@ -51,8 +51,23 @@ else:
|
|||
locals()["Namespace"] = object
|
||||
|
||||
|
||||
class _Registry(dict):
|
||||
def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> Type:
|
||||
_deprecate_registry_message = (
|
||||
"`LightningCLI`'s registries were deprecated in v1.7 and will be removed "
|
||||
"in v1.9. Now any imported subclass is automatically available by name in "
|
||||
"`LightningCLI` without any need to explicitly register it."
|
||||
)
|
||||
|
||||
_deprecate_auto_registry_message = (
|
||||
"`LightningCLI.auto_registry` parameter was deprecated in v1.7 and will be removed "
|
||||
"in v1.9. Now any imported subclass is automatically available by name in "
|
||||
"`LightningCLI` without any need to explicitly register it."
|
||||
)
|
||||
|
||||
|
||||
class _Registry(dict): # Remove in v1.9
|
||||
def __call__(
|
||||
self, cls: Type, key: Optional[str] = None, override: bool = False, show_deprecation: bool = True
|
||||
) -> Type:
|
||||
"""Registers a class mapped to a name.
|
||||
|
||||
Args:
|
||||
|
@ -67,12 +82,16 @@ class _Registry(dict):
|
|||
|
||||
if key not in self or override:
|
||||
self[key] = cls
|
||||
|
||||
self._deprecation(show_deprecation)
|
||||
return cls
|
||||
|
||||
def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
|
||||
def register_classes(
|
||||
self, module: ModuleType, base_cls: Type, override: bool = False, show_deprecation: bool = True
|
||||
) -> None:
|
||||
"""This function is an utility to register all classes from a module."""
|
||||
for cls in self.get_members(module, base_cls):
|
||||
self(cls=cls, override=override)
|
||||
self(cls=cls, override=override, show_deprecation=show_deprecation)
|
||||
|
||||
@staticmethod
|
||||
def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]:
|
||||
|
@ -85,16 +104,23 @@ class _Registry(dict):
|
|||
@property
|
||||
def names(self) -> List[str]:
|
||||
"""Returns the registered names."""
|
||||
self._deprecation()
|
||||
return list(self.keys())
|
||||
|
||||
@property
|
||||
def classes(self) -> Tuple[Type, ...]:
|
||||
"""Returns the registered classes."""
|
||||
self._deprecation()
|
||||
return tuple(self.values())
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Registered objects: {self.names}"
|
||||
|
||||
def _deprecation(self, show_deprecation: bool = True) -> None:
|
||||
if show_deprecation and not getattr(self, "deprecation_shown", False):
|
||||
rank_zero_deprecation(_deprecate_registry_message)
|
||||
self.deprecation_shown = True
|
||||
|
||||
|
||||
OPTIMIZER_REGISTRY = _Registry()
|
||||
LR_SCHEDULER_REGISTRY = _Registry()
|
||||
|
@ -116,29 +142,32 @@ LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPl
|
|||
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
|
||||
|
||||
|
||||
def _populate_registries(subclasses: bool) -> None:
|
||||
def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
|
||||
if subclasses:
|
||||
rank_zero_deprecation(_deprecate_auto_registry_message)
|
||||
# this will register any subclasses from all loaded modules including userland
|
||||
for cls in get_all_subclasses(torch.optim.Optimizer):
|
||||
OPTIMIZER_REGISTRY(cls)
|
||||
OPTIMIZER_REGISTRY(cls, show_deprecation=False)
|
||||
for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
|
||||
LR_SCHEDULER_REGISTRY(cls)
|
||||
LR_SCHEDULER_REGISTRY(cls, show_deprecation=False)
|
||||
for cls in get_all_subclasses(pl.Callback):
|
||||
CALLBACK_REGISTRY(cls)
|
||||
CALLBACK_REGISTRY(cls, show_deprecation=False)
|
||||
for cls in get_all_subclasses(pl.LightningModule):
|
||||
MODEL_REGISTRY(cls)
|
||||
MODEL_REGISTRY(cls, show_deprecation=False)
|
||||
for cls in get_all_subclasses(pl.LightningDataModule):
|
||||
DATAMODULE_REGISTRY(cls)
|
||||
DATAMODULE_REGISTRY(cls, show_deprecation=False)
|
||||
for cls in get_all_subclasses(pl.loggers.Logger):
|
||||
LOGGER_REGISTRY(cls)
|
||||
LOGGER_REGISTRY(cls, show_deprecation=False)
|
||||
else:
|
||||
# manually register torch's subclasses and our subclasses
|
||||
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
|
||||
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
|
||||
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback)
|
||||
LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger)
|
||||
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer, show_deprecation=False)
|
||||
LR_SCHEDULER_REGISTRY.register_classes(
|
||||
torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler, show_deprecation=False
|
||||
)
|
||||
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False)
|
||||
LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False)
|
||||
# `ReduceLROnPlateau` does not subclass `_LRScheduler`
|
||||
LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)
|
||||
LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau, show_deprecation=False)
|
||||
|
||||
|
||||
class LightningArgumentParser(ArgumentParser):
|
||||
|
@ -211,14 +240,14 @@ class LightningArgumentParser(ArgumentParser):
|
|||
|
||||
def add_optimizer_args(
|
||||
self,
|
||||
optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]],
|
||||
optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,),
|
||||
nested_key: str = "optimizer",
|
||||
link_to: str = "AUTOMATIC",
|
||||
) -> None:
|
||||
"""Adds arguments from an optimizer class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`.
|
||||
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||
"""
|
||||
|
@ -235,14 +264,15 @@ class LightningArgumentParser(ArgumentParser):
|
|||
|
||||
def add_lr_scheduler_args(
|
||||
self,
|
||||
lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]],
|
||||
lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple,
|
||||
nested_key: str = "lr_scheduler",
|
||||
link_to: str = "AUTOMATIC",
|
||||
) -> None:
|
||||
"""Adds arguments from a learning rate scheduler class to a nested key of the parser.
|
||||
|
||||
Args:
|
||||
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``.
|
||||
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use
|
||||
tuple to allow subclasses.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||
"""
|
||||
|
|
|
@ -19,7 +19,14 @@ import pytest
|
|||
import pytorch_lightning.loggers.base as logger_base
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.core.module import LightningModule
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||
from pytorch_lightning.utilities.cli import (
|
||||
_deprecate_auto_registry_message,
|
||||
_deprecate_registry_message,
|
||||
CALLBACK_REGISTRY,
|
||||
LightningCLI,
|
||||
SaveConfigCallback,
|
||||
)
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||
|
||||
|
||||
|
@ -133,3 +140,17 @@ def test_deprecated_dataloader_reset():
|
|||
trainer = Trainer()
|
||||
with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"):
|
||||
trainer.reset_train_val_dataloaders()
|
||||
|
||||
|
||||
def test_lightningCLI_registries_register():
|
||||
with pytest.deprecated_call(match=_deprecate_registry_message):
|
||||
|
||||
@CALLBACK_REGISTRY
|
||||
class CustomCallback(SaveConfigCallback):
|
||||
pass
|
||||
|
||||
|
||||
def test_lightningCLI_registries_register_automatically():
|
||||
with pytest.deprecated_call(match=_deprecate_auto_registry_message):
|
||||
with mock.patch("sys.argv", ["any.py"]):
|
||||
LightningCLI(BoringModel, run=False, auto_registry=True)
|
||||
|
|
|
@ -34,22 +34,15 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
|||
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
|
||||
from pytorch_lightning.loggers import Logger, TensorBoardLogger
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.cli import (
|
||||
_populate_registries,
|
||||
CALLBACK_REGISTRY,
|
||||
DATAMODULE_REGISTRY,
|
||||
instantiate_class,
|
||||
LightningArgumentParser,
|
||||
LightningCLI,
|
||||
LOGGER_REGISTRY,
|
||||
LR_SCHEDULER_REGISTRY,
|
||||
LRSchedulerTypeTuple,
|
||||
MODEL_REGISTRY,
|
||||
OPTIMIZER_REGISTRY,
|
||||
SaveConfigCallback,
|
||||
)
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -914,72 +907,6 @@ def test_lightning_cli_run():
|
|||
assert isinstance(cli.model, LightningModule)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_registries():
|
||||
# since the registries are global, it's good to clear them after each test to avoid unwanted interactions
|
||||
yield
|
||||
OPTIMIZER_REGISTRY.clear()
|
||||
LR_SCHEDULER_REGISTRY.clear()
|
||||
CALLBACK_REGISTRY.clear()
|
||||
MODEL_REGISTRY.clear()
|
||||
DATAMODULE_REGISTRY.clear()
|
||||
LOGGER_REGISTRY.clear()
|
||||
|
||||
|
||||
def test_registries():
|
||||
# the registries are global so this is only necessary when this test is run standalone
|
||||
_populate_registries(False)
|
||||
|
||||
@OPTIMIZER_REGISTRY
|
||||
class CustomAdam(torch.optim.Adam):
|
||||
pass
|
||||
|
||||
@LR_SCHEDULER_REGISTRY
|
||||
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
|
||||
pass
|
||||
|
||||
@CALLBACK_REGISTRY
|
||||
class CustomCallback(Callback):
|
||||
pass
|
||||
|
||||
@LOGGER_REGISTRY
|
||||
class CustomLogger(Logger):
|
||||
pass
|
||||
|
||||
assert "SGD" in OPTIMIZER_REGISTRY.names
|
||||
assert "RMSprop" in OPTIMIZER_REGISTRY.names
|
||||
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
|
||||
|
||||
assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
|
||||
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
|
||||
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
|
||||
assert "ReduceLROnPlateau" in LR_SCHEDULER_REGISTRY.names
|
||||
|
||||
assert "EarlyStopping" in CALLBACK_REGISTRY.names
|
||||
assert "CustomCallback" in CALLBACK_REGISTRY.names
|
||||
|
||||
class Foo:
|
||||
...
|
||||
|
||||
OPTIMIZER_REGISTRY(Foo, key="SGD") # not overridden by default
|
||||
assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD
|
||||
OPTIMIZER_REGISTRY(Foo, key="SGD", override=True)
|
||||
assert OPTIMIZER_REGISTRY["SGD"] is Foo
|
||||
|
||||
# test `_Registry.__call__` returns the class
|
||||
assert isinstance(CustomCallback(), CustomCallback)
|
||||
|
||||
assert "WandbLogger" in LOGGER_REGISTRY
|
||||
assert "CustomLogger" in LOGGER_REGISTRY
|
||||
|
||||
|
||||
def test_registries_register_automatically():
|
||||
assert "SaveConfigCallback" not in CALLBACK_REGISTRY
|
||||
with mock.patch("sys.argv", ["any.py"]):
|
||||
LightningCLI(BoringModel, run=False, auto_registry=True)
|
||||
assert "SaveConfigCallback" in CALLBACK_REGISTRY
|
||||
|
||||
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self, foo, bar=5):
|
||||
super().__init__()
|
||||
|
@ -1137,11 +1064,11 @@ def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload
|
|||
super().__init__(*args, run=False)
|
||||
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
|
||||
parser.add_optimizer_args(nested_key="opt1", link_to="model.opt1_config")
|
||||
parser.add_optimizer_args(
|
||||
(torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
|
||||
)
|
||||
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
|
||||
parser.add_lr_scheduler_args(link_to="model.sch_config")
|
||||
parser.add_argument("--something", type=str, nargs="+")
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
|
Loading…
Reference in New Issue