Deprecate CLI registries and update documentation (#13221)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
ad87d2cad0
commit
0ae9627bf8
|
@ -130,6 +130,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
|
|
||||||
- Deprecated `Trainer.reset_train_val_dataloaders()` in favor of `Trainer.reset_{train,val}_dataloader` ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
|
- Deprecated `Trainer.reset_train_val_dataloaders()` in favor of `Trainer.reset_{train,val}_dataloader` ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
|
||||||
|
|
||||||
|
|
||||||
|
- Deprecated LightningCLI's registries in favor of importing the respective package ([#13221](https://github.com/PyTorchLightning/pytorch-lightning/pull/13221))
|
||||||
|
|
||||||
|
|
||||||
### Removed
|
### Removed
|
||||||
|
|
||||||
- Removed the deprecated `Logger.close` method ([#13149](https://github.com/PyTorchLightning/pytorch-lightning/pull/13149))
|
- Removed the deprecated `Logger.close` method ([#13149](https://github.com/PyTorchLightning/pytorch-lightning/pull/13149))
|
||||||
|
|
|
@ -27,7 +27,7 @@ Basic use
|
||||||
|
|
||||||
.. displayitem::
|
.. displayitem::
|
||||||
:header: 2: Mix models and datasets
|
:header: 2: Mix models and datasets
|
||||||
:description: Register models, datasets, optimizers and learning rate schedulers
|
:description: Support multiple models, datasets, optimizers and learning rate schedulers
|
||||||
:col_css: col-md-4
|
:col_css: col-md-4
|
||||||
:button_link: lightning_cli_intermediate_2.html
|
:button_link: lightning_cli_intermediate_2.html
|
||||||
:height: 150
|
:height: 150
|
||||||
|
@ -66,8 +66,8 @@ Advanced use
|
||||||
:tag: advanced
|
:tag: advanced
|
||||||
|
|
||||||
.. displayitem::
|
.. displayitem::
|
||||||
:header: Customize configs for complex projects
|
:header: Customize for complex projects
|
||||||
:description: Learn how to connect complex projects with each Registry.
|
:description: Learn how to implement CLIs for complex projects.
|
||||||
:col_css: col-md-6
|
:col_css: col-md-6
|
||||||
:button_link: lightning_cli_advanced_3.html
|
:button_link: lightning_cli_advanced_3.html
|
||||||
:height: 150
|
:height: 150
|
||||||
|
|
|
@ -63,29 +63,6 @@ This can be useful to implement custom logic without having to subclass the CLI,
|
||||||
and argument parsing capabilities.
|
and argument parsing capabilities.
|
||||||
|
|
||||||
|
|
||||||
Subclass registration
|
|
||||||
^^^^^^^^^^^^^^^^^^^^^
|
|
||||||
|
|
||||||
To use shorthand notation, the options need to be registered beforehand. This can be easily done with:
|
|
||||||
|
|
||||||
.. code-block::
|
|
||||||
|
|
||||||
LightningCLI(auto_registry=True) # False by default
|
|
||||||
|
|
||||||
which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`,
|
|
||||||
:class:`~pytorch_lightning.core.module.LightningModule`,
|
|
||||||
:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and
|
|
||||||
:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules. This includes those in your own
|
|
||||||
code.
|
|
||||||
|
|
||||||
Alternatively, if this is left unset, only the subclasses defined in PyTorch's :class:`torch.optim.Optimizer`,
|
|
||||||
:class:`torch.optim.lr_scheduler._LRScheduler` and Lightning's :class:`~pytorch_lightning.callbacks.Callback` and
|
|
||||||
:class:`~pytorch_lightning.loggers.LightningLoggerBase` subclassess will be registered.
|
|
||||||
|
|
||||||
In subsequent sections, we will go over adding specific classes to specific registries as well as how to use
|
|
||||||
shorthand notation.
|
|
||||||
|
|
||||||
|
|
||||||
Trainer Callbacks and arguments with class type
|
Trainer Callbacks and arguments with class type
|
||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
@ -107,14 +84,14 @@ file example that defines a couple of callbacks is the following:
|
||||||
init_args:
|
init_args:
|
||||||
...
|
...
|
||||||
|
|
||||||
Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended
|
Similar to the callbacks, any parameter in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended
|
||||||
:class:`~pytorch_lightning.core.module.LightningModule` and
|
:class:`~pytorch_lightning.core.module.LightningModule` and
|
||||||
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured
|
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class, can be
|
||||||
the same way using :code:`class_path` and :code:`init_args`.
|
configured the same way using :code:`class_path` and :code:`init_args`. If the package that defines a subclass is
|
||||||
|
imported before the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class is run, the name can be used instead of
|
||||||
|
the full import path.
|
||||||
|
|
||||||
For callbacks in particular, Lightning simplifies the command line so that only
|
From command line the syntax is the following:
|
||||||
the :class:`~pytorch_lightning.callbacks.Callback` name is required.
|
|
||||||
The argument's order matters and the user needs to pass the arguments in the following way.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
@ -127,7 +104,8 @@ The argument's order matters and the user needs to pass the arguments in the fol
|
||||||
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
|
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
|
||||||
...
|
...
|
||||||
|
|
||||||
Here is an example:
|
Note the use of ``+`` to append a new callback to the list and that the ``init_args`` are applied to the previous
|
||||||
|
callback appended. Here is an example:
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
|
@ -137,43 +115,11 @@ Here is an example:
|
||||||
--trainer.callbacks+=LearningRateMonitor \
|
--trainer.callbacks+=LearningRateMonitor \
|
||||||
--trainer.callbacks.logging_interval=epoch
|
--trainer.callbacks.logging_interval=epoch
|
||||||
|
|
||||||
Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
|
|
||||||
as described above:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
|
|
||||||
|
|
||||||
|
|
||||||
@CALLBACK_REGISTRY
|
|
||||||
class CustomCallback(Callback):
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
cli = LightningCLI(...)
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
$ python ... --trainer.callbacks+=CustomCallback ...
|
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
This shorthand notation is also supported inside a configuration file. The configuration file
|
Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback`)
|
||||||
generated by calling the previous command with ``--print_config`` will have the full ``class_path`` notation.
|
always have the full ``class_path``'s, even when class name shorthand notation is used in command line or in input
|
||||||
|
config files.
|
||||||
.. code-block:: yaml
|
|
||||||
|
|
||||||
trainer:
|
|
||||||
callbacks:
|
|
||||||
- class_path: your_class_path.CustomCallback
|
|
||||||
init_args:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
.. tip::
|
|
||||||
|
|
||||||
``--trainer.logger`` also supports shorthand notation and a ``LOGGER_REGISTRY`` is available to register custom
|
|
||||||
Loggers.
|
|
||||||
|
|
||||||
|
|
||||||
Multiple models and/or datasets
|
Multiple models and/or datasets
|
||||||
|
@ -377,12 +323,8 @@ example can be when one wants to add support for multiple optimizers:
|
||||||
|
|
||||||
class MyLightningCLI(LightningCLI):
|
class MyLightningCLI(LightningCLI):
|
||||||
def add_arguments_to_parser(self, parser):
|
def add_arguments_to_parser(self, parser):
|
||||||
parser.add_optimizer_args(
|
parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
|
||||||
OPTIMIZER_REGISTRY.classes, nested_key="gen_optimizer", link_to="model.optimizer1_init"
|
parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")
|
||||||
)
|
|
||||||
parser.add_optimizer_args(
|
|
||||||
OPTIMIZER_REGISTRY.classes, nested_key="gen_discriminator", link_to="model.optimizer2_init"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
cli = MyLightningCLI(MyModel)
|
cli = MyLightningCLI(MyModel)
|
||||||
|
@ -398,18 +340,17 @@ With shorthand notation:
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ python trainer.py fit \
|
$ python trainer.py fit \
|
||||||
--gen_optimizer=Adam \
|
--optimizer1=Adam \
|
||||||
--gen_optimizer.lr=0.01 \
|
--optimizer1.lr=0.01 \
|
||||||
--gen_discriminator=AdamW \
|
--optimizer2=AdamW \
|
||||||
--gen_discriminator.lr=0.0001
|
--optimizer2.lr=0.0001
|
||||||
|
|
||||||
You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
|
You can also pass the class path directly, for example, if the optimizer hasn't been imported:
|
||||||
``OPTIMIZER_REGISTRY``:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
.. code-block:: bash
|
||||||
|
|
||||||
$ python trainer.py fit \
|
$ python trainer.py fit \
|
||||||
--gen_optimizer.class_path=torch.optim.Adam \
|
--optimizer1=torch.optim.Adam \
|
||||||
--gen_optimizer.init_args.lr=0.01 \
|
--optimizer1.lr=0.01 \
|
||||||
--gen_discriminator.class_path=torch.optim.AdamW \
|
--optimizer2=torch.optim.AdamW \
|
||||||
--gen_discriminator.init_args.lr=0.0001
|
--optimizer2.lr=0.0001
|
||||||
|
|
|
@ -46,9 +46,9 @@ This is what the Lightning CLI enables. Otherwise, this kind of configuration re
|
||||||
----
|
----
|
||||||
|
|
||||||
*************************
|
*************************
|
||||||
Register LightningModules
|
Multiple LightningModules
|
||||||
*************************
|
*************************
|
||||||
Connect models across different files with the ``MODEL_REGISTRY`` to make them available from the CLI:
|
To support multiple models, when instantiating ``LightningCLI`` omit the ``model_class`` parameter:
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
|
@ -58,14 +58,12 @@ Connect models across different files with the ``MODEL_REGISTRY`` to make them a
|
||||||
from pytorch_lightning.utilities import cli as pl_cli
|
from pytorch_lightning.utilities import cli as pl_cli
|
||||||
|
|
||||||
|
|
||||||
@pl_cli.MODEL_REGISTRY
|
|
||||||
class Model1(DemoModel):
|
class Model1(DemoModel):
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
print("⚡", "using Model1", "⚡")
|
print("⚡", "using Model1", "⚡")
|
||||||
return super().configure_optimizers()
|
return super().configure_optimizers()
|
||||||
|
|
||||||
|
|
||||||
@pl_cli.MODEL_REGISTRY
|
|
||||||
class Model2(DemoModel):
|
class Model2(DemoModel):
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
print("⚡", "using Model2", "⚡")
|
print("⚡", "using Model2", "⚡")
|
||||||
|
@ -87,9 +85,9 @@ Now you can choose between any model from the CLI:
|
||||||
----
|
----
|
||||||
|
|
||||||
********************
|
********************
|
||||||
Register DataModules
|
Multiple DataModules
|
||||||
********************
|
********************
|
||||||
Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to make them available from the CLI:
|
To support multiple data modules, when instantiating ``LightningCLI`` omit the ``datamodule_class`` parameter:
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
|
@ -99,14 +97,12 @@ Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to m
|
||||||
from pytorch_lightning import demos
|
from pytorch_lightning import demos
|
||||||
|
|
||||||
|
|
||||||
@pl_cli.DATAMODULE_REGISTRY
|
|
||||||
class FakeDataset1(BoringDataModule):
|
class FakeDataset1(BoringDataModule):
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
print("⚡", "using FakeDataset1", "⚡")
|
print("⚡", "using FakeDataset1", "⚡")
|
||||||
return torch.utils.data.DataLoader(self.random_train)
|
return torch.utils.data.DataLoader(self.random_train)
|
||||||
|
|
||||||
|
|
||||||
@pl_cli.DATAMODULE_REGISTRY
|
|
||||||
class FakeDataset2(BoringDataModule):
|
class FakeDataset2(BoringDataModule):
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
print("⚡", "using FakeDataset2", "⚡")
|
print("⚡", "using FakeDataset2", "⚡")
|
||||||
|
@ -127,10 +123,10 @@ Now you can choose between any dataset at runtime:
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
*******************
|
*****************
|
||||||
Register optimizers
|
Custom optimizers
|
||||||
*******************
|
*****************
|
||||||
Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from the CLI:
|
Any subclass of ``torch.optim.Optimizer`` can be used as an optimizer:
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
|
@ -140,14 +136,12 @@ Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from t
|
||||||
from pytorch_lightning import demos
|
from pytorch_lightning import demos
|
||||||
|
|
||||||
|
|
||||||
@pl_cli.OPTIMIZER_REGISTRY
|
|
||||||
class LitAdam(torch.optim.Adam):
|
class LitAdam(torch.optim.Adam):
|
||||||
def step(self, closure):
|
def step(self, closure):
|
||||||
print("⚡", "using LitAdam", "⚡")
|
print("⚡", "using LitAdam", "⚡")
|
||||||
super().step(closure)
|
super().step(closure)
|
||||||
|
|
||||||
|
|
||||||
@pl_cli.OPTIMIZER_REGISTRY
|
|
||||||
class FancyAdam(torch.optim.Adam):
|
class FancyAdam(torch.optim.Adam):
|
||||||
def step(self, closure):
|
def step(self, closure):
|
||||||
print("⚡", "using FancyAdam", "⚡")
|
print("⚡", "using FancyAdam", "⚡")
|
||||||
|
@ -166,7 +160,8 @@ Now you can choose between any optimizer at runtime:
|
||||||
# use FancyAdam
|
# use FancyAdam
|
||||||
python main.py fit --optimizer FancyAdam
|
python main.py fit --optimizer FancyAdam
|
||||||
|
|
||||||
Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from ``torch.optim.optim``:
|
Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from
|
||||||
|
``torch.optim``:
|
||||||
|
|
||||||
.. code:: bash
|
.. code:: bash
|
||||||
|
|
||||||
|
@ -180,10 +175,10 @@ If the optimizer you want needs other arguments, add them via the CLI (no need t
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
**********************
|
********************
|
||||||
Register LR schedulers
|
Custom LR schedulers
|
||||||
**********************
|
********************
|
||||||
Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them available from the CLI:
|
Any subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
|
@ -193,7 +188,6 @@ Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them
|
||||||
from pytorch_lightning import demos
|
from pytorch_lightning import demos
|
||||||
|
|
||||||
|
|
||||||
@pl_cli.LR_SCHEDULER_REGISTRY
|
|
||||||
class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
|
class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
|
||||||
def step(self):
|
def step(self):
|
||||||
print("⚡", "using LitLRScheduler", "⚡")
|
print("⚡", "using LitLRScheduler", "⚡")
|
||||||
|
@ -210,7 +204,8 @@ Now you can choose between any learning rate scheduler at runtime:
|
||||||
python main.py fit --lr_scheduler LitLRScheduler
|
python main.py fit --lr_scheduler LitLRScheduler
|
||||||
|
|
||||||
|
|
||||||
Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from ``torch.optim``:
|
Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from
|
||||||
|
``torch.optim``:
|
||||||
|
|
||||||
.. code:: bash
|
.. code:: bash
|
||||||
|
|
||||||
|
@ -226,26 +221,31 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
|
||||||
|
|
||||||
----
|
----
|
||||||
|
|
||||||
*************************
|
************************
|
||||||
Register from any package
|
Classes from any package
|
||||||
*************************
|
************************
|
||||||
A shortcut to register many classes from a package is to use the ``register_classes`` method. Here we register all optimizers from the ``torch.optim`` library:
|
In the previous sections the classes to select were defined in the same python file where the ``LightningCLI`` class is
|
||||||
|
run. To select classes from any package by using only the class name, import the respective package:
|
||||||
|
|
||||||
.. code:: python
|
.. code:: python
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch_lightning.utilities import cli as pl_cli
|
from pytorch_lightning.utilities import cli as pl_cli
|
||||||
from pytorch_lightning import demos
|
import my_code.models # noqa: F401
|
||||||
|
import my_code.data_modules # noqa: F401
|
||||||
|
import my_code.optimizers # noqa: F401
|
||||||
|
|
||||||
# add all PyTorch optimizers!
|
cli = pl_cli.LightningCLI()
|
||||||
pl_cli.OPTIMIZER_REGISTRY.register_classes(module=torch.optim, base_cls=torch.optim.Optimizer)
|
|
||||||
|
|
||||||
cli = pl_cli.LightningCLI(DemoModel, BoringDataModule)
|
Now use any of the classes:
|
||||||
|
|
||||||
Now use any of the optimizers in the ``torch.optim`` library:
|
|
||||||
|
|
||||||
.. code:: bash
|
.. code:: bash
|
||||||
|
|
||||||
python main.py fit --optimizer AdamW
|
python main.py fit --model Model1 --data FakeDataset1 --optimizer LitAdam --lr_scheduler LitLRScheduler
|
||||||
|
|
||||||
This method is supported by all the registry classes.
|
The ``# noqa: F401`` comment avoids a linter warning that the import is unused. It is also possible to select subclasses
|
||||||
|
that have not been imported by giving the full import path:
|
||||||
|
|
||||||
|
.. code:: bash
|
||||||
|
|
||||||
|
python main.py fit --model my_code.models.Model1
|
||||||
|
|
|
@ -51,8 +51,23 @@ else:
|
||||||
locals()["Namespace"] = object
|
locals()["Namespace"] = object
|
||||||
|
|
||||||
|
|
||||||
class _Registry(dict):
|
_deprecate_registry_message = (
|
||||||
def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> Type:
|
"`LightningCLI`'s registries were deprecated in v1.7 and will be removed "
|
||||||
|
"in v1.9. Now any imported subclass is automatically available by name in "
|
||||||
|
"`LightningCLI` without any need to explicitly register it."
|
||||||
|
)
|
||||||
|
|
||||||
|
_deprecate_auto_registry_message = (
|
||||||
|
"`LightningCLI.auto_registry` parameter was deprecated in v1.7 and will be removed "
|
||||||
|
"in v1.9. Now any imported subclass is automatically available by name in "
|
||||||
|
"`LightningCLI` without any need to explicitly register it."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _Registry(dict): # Remove in v1.9
|
||||||
|
def __call__(
|
||||||
|
self, cls: Type, key: Optional[str] = None, override: bool = False, show_deprecation: bool = True
|
||||||
|
) -> Type:
|
||||||
"""Registers a class mapped to a name.
|
"""Registers a class mapped to a name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -67,12 +82,16 @@ class _Registry(dict):
|
||||||
|
|
||||||
if key not in self or override:
|
if key not in self or override:
|
||||||
self[key] = cls
|
self[key] = cls
|
||||||
|
|
||||||
|
self._deprecation(show_deprecation)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
|
def register_classes(
|
||||||
|
self, module: ModuleType, base_cls: Type, override: bool = False, show_deprecation: bool = True
|
||||||
|
) -> None:
|
||||||
"""This function is an utility to register all classes from a module."""
|
"""This function is an utility to register all classes from a module."""
|
||||||
for cls in self.get_members(module, base_cls):
|
for cls in self.get_members(module, base_cls):
|
||||||
self(cls=cls, override=override)
|
self(cls=cls, override=override, show_deprecation=show_deprecation)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]:
|
def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]:
|
||||||
|
@ -85,16 +104,23 @@ class _Registry(dict):
|
||||||
@property
|
@property
|
||||||
def names(self) -> List[str]:
|
def names(self) -> List[str]:
|
||||||
"""Returns the registered names."""
|
"""Returns the registered names."""
|
||||||
|
self._deprecation()
|
||||||
return list(self.keys())
|
return list(self.keys())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def classes(self) -> Tuple[Type, ...]:
|
def classes(self) -> Tuple[Type, ...]:
|
||||||
"""Returns the registered classes."""
|
"""Returns the registered classes."""
|
||||||
|
self._deprecation()
|
||||||
return tuple(self.values())
|
return tuple(self.values())
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"Registered objects: {self.names}"
|
return f"Registered objects: {self.names}"
|
||||||
|
|
||||||
|
def _deprecation(self, show_deprecation: bool = True) -> None:
|
||||||
|
if show_deprecation and not getattr(self, "deprecation_shown", False):
|
||||||
|
rank_zero_deprecation(_deprecate_registry_message)
|
||||||
|
self.deprecation_shown = True
|
||||||
|
|
||||||
|
|
||||||
OPTIMIZER_REGISTRY = _Registry()
|
OPTIMIZER_REGISTRY = _Registry()
|
||||||
LR_SCHEDULER_REGISTRY = _Registry()
|
LR_SCHEDULER_REGISTRY = _Registry()
|
||||||
|
@ -116,29 +142,32 @@ LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPl
|
||||||
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
|
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]
|
||||||
|
|
||||||
|
|
||||||
def _populate_registries(subclasses: bool) -> None:
|
def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
|
||||||
if subclasses:
|
if subclasses:
|
||||||
|
rank_zero_deprecation(_deprecate_auto_registry_message)
|
||||||
# this will register any subclasses from all loaded modules including userland
|
# this will register any subclasses from all loaded modules including userland
|
||||||
for cls in get_all_subclasses(torch.optim.Optimizer):
|
for cls in get_all_subclasses(torch.optim.Optimizer):
|
||||||
OPTIMIZER_REGISTRY(cls)
|
OPTIMIZER_REGISTRY(cls, show_deprecation=False)
|
||||||
for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
|
for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
|
||||||
LR_SCHEDULER_REGISTRY(cls)
|
LR_SCHEDULER_REGISTRY(cls, show_deprecation=False)
|
||||||
for cls in get_all_subclasses(pl.Callback):
|
for cls in get_all_subclasses(pl.Callback):
|
||||||
CALLBACK_REGISTRY(cls)
|
CALLBACK_REGISTRY(cls, show_deprecation=False)
|
||||||
for cls in get_all_subclasses(pl.LightningModule):
|
for cls in get_all_subclasses(pl.LightningModule):
|
||||||
MODEL_REGISTRY(cls)
|
MODEL_REGISTRY(cls, show_deprecation=False)
|
||||||
for cls in get_all_subclasses(pl.LightningDataModule):
|
for cls in get_all_subclasses(pl.LightningDataModule):
|
||||||
DATAMODULE_REGISTRY(cls)
|
DATAMODULE_REGISTRY(cls, show_deprecation=False)
|
||||||
for cls in get_all_subclasses(pl.loggers.Logger):
|
for cls in get_all_subclasses(pl.loggers.Logger):
|
||||||
LOGGER_REGISTRY(cls)
|
LOGGER_REGISTRY(cls, show_deprecation=False)
|
||||||
else:
|
else:
|
||||||
# manually register torch's subclasses and our subclasses
|
# manually register torch's subclasses and our subclasses
|
||||||
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
|
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer, show_deprecation=False)
|
||||||
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
|
LR_SCHEDULER_REGISTRY.register_classes(
|
||||||
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback)
|
torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler, show_deprecation=False
|
||||||
LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger)
|
)
|
||||||
|
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False)
|
||||||
|
LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False)
|
||||||
# `ReduceLROnPlateau` does not subclass `_LRScheduler`
|
# `ReduceLROnPlateau` does not subclass `_LRScheduler`
|
||||||
LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)
|
LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau, show_deprecation=False)
|
||||||
|
|
||||||
|
|
||||||
class LightningArgumentParser(ArgumentParser):
|
class LightningArgumentParser(ArgumentParser):
|
||||||
|
@ -211,14 +240,14 @@ class LightningArgumentParser(ArgumentParser):
|
||||||
|
|
||||||
def add_optimizer_args(
|
def add_optimizer_args(
|
||||||
self,
|
self,
|
||||||
optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]],
|
optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,),
|
||||||
nested_key: str = "optimizer",
|
nested_key: str = "optimizer",
|
||||||
link_to: str = "AUTOMATIC",
|
link_to: str = "AUTOMATIC",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Adds arguments from an optimizer class to a nested key of the parser.
|
"""Adds arguments from an optimizer class to a nested key of the parser.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`.
|
optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses.
|
||||||
nested_key: Name of the nested namespace to store arguments.
|
nested_key: Name of the nested namespace to store arguments.
|
||||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||||
"""
|
"""
|
||||||
|
@ -235,14 +264,15 @@ class LightningArgumentParser(ArgumentParser):
|
||||||
|
|
||||||
def add_lr_scheduler_args(
|
def add_lr_scheduler_args(
|
||||||
self,
|
self,
|
||||||
lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]],
|
lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple,
|
||||||
nested_key: str = "lr_scheduler",
|
nested_key: str = "lr_scheduler",
|
||||||
link_to: str = "AUTOMATIC",
|
link_to: str = "AUTOMATIC",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Adds arguments from a learning rate scheduler class to a nested key of the parser.
|
"""Adds arguments from a learning rate scheduler class to a nested key of the parser.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``.
|
lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use
|
||||||
|
tuple to allow subclasses.
|
||||||
nested_key: Name of the nested namespace to store arguments.
|
nested_key: Name of the nested namespace to store arguments.
|
||||||
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -19,7 +19,14 @@ import pytest
|
||||||
import pytorch_lightning.loggers.base as logger_base
|
import pytorch_lightning.loggers.base as logger_base
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.core.module import LightningModule
|
from pytorch_lightning.core.module import LightningModule
|
||||||
from pytorch_lightning.utilities.cli import LightningCLI
|
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||||
|
from pytorch_lightning.utilities.cli import (
|
||||||
|
_deprecate_auto_registry_message,
|
||||||
|
_deprecate_registry_message,
|
||||||
|
CALLBACK_REGISTRY,
|
||||||
|
LightningCLI,
|
||||||
|
SaveConfigCallback,
|
||||||
|
)
|
||||||
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
from pytorch_lightning.utilities.rank_zero import rank_zero_only
|
||||||
|
|
||||||
|
|
||||||
|
@ -133,3 +140,17 @@ def test_deprecated_dataloader_reset():
|
||||||
trainer = Trainer()
|
trainer = Trainer()
|
||||||
with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"):
|
with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"):
|
||||||
trainer.reset_train_val_dataloaders()
|
trainer.reset_train_val_dataloaders()
|
||||||
|
|
||||||
|
|
||||||
|
def test_lightningCLI_registries_register():
|
||||||
|
with pytest.deprecated_call(match=_deprecate_registry_message):
|
||||||
|
|
||||||
|
@CALLBACK_REGISTRY
|
||||||
|
class CustomCallback(SaveConfigCallback):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_lightningCLI_registries_register_automatically():
|
||||||
|
with pytest.deprecated_call(match=_deprecate_auto_registry_message):
|
||||||
|
with mock.patch("sys.argv", ["any.py"]):
|
||||||
|
LightningCLI(BoringModel, run=False, auto_registry=True)
|
||||||
|
|
|
@ -34,22 +34,15 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
|
||||||
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
|
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
|
||||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||||
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
|
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
|
||||||
from pytorch_lightning.loggers import Logger, TensorBoardLogger
|
from pytorch_lightning.loggers import TensorBoardLogger
|
||||||
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
from pytorch_lightning.plugins.environments import SLURMEnvironment
|
||||||
from pytorch_lightning.trainer.states import TrainerFn
|
from pytorch_lightning.trainer.states import TrainerFn
|
||||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||||
from pytorch_lightning.utilities.cli import (
|
from pytorch_lightning.utilities.cli import (
|
||||||
_populate_registries,
|
|
||||||
CALLBACK_REGISTRY,
|
|
||||||
DATAMODULE_REGISTRY,
|
|
||||||
instantiate_class,
|
instantiate_class,
|
||||||
LightningArgumentParser,
|
LightningArgumentParser,
|
||||||
LightningCLI,
|
LightningCLI,
|
||||||
LOGGER_REGISTRY,
|
|
||||||
LR_SCHEDULER_REGISTRY,
|
|
||||||
LRSchedulerTypeTuple,
|
LRSchedulerTypeTuple,
|
||||||
MODEL_REGISTRY,
|
|
||||||
OPTIMIZER_REGISTRY,
|
|
||||||
SaveConfigCallback,
|
SaveConfigCallback,
|
||||||
)
|
)
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
|
@ -914,72 +907,6 @@ def test_lightning_cli_run():
|
||||||
assert isinstance(cli.model, LightningModule)
|
assert isinstance(cli.model, LightningModule)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def clear_registries():
|
|
||||||
# since the registries are global, it's good to clear them after each test to avoid unwanted interactions
|
|
||||||
yield
|
|
||||||
OPTIMIZER_REGISTRY.clear()
|
|
||||||
LR_SCHEDULER_REGISTRY.clear()
|
|
||||||
CALLBACK_REGISTRY.clear()
|
|
||||||
MODEL_REGISTRY.clear()
|
|
||||||
DATAMODULE_REGISTRY.clear()
|
|
||||||
LOGGER_REGISTRY.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def test_registries():
|
|
||||||
# the registries are global so this is only necessary when this test is run standalone
|
|
||||||
_populate_registries(False)
|
|
||||||
|
|
||||||
@OPTIMIZER_REGISTRY
|
|
||||||
class CustomAdam(torch.optim.Adam):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@LR_SCHEDULER_REGISTRY
|
|
||||||
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@CALLBACK_REGISTRY
|
|
||||||
class CustomCallback(Callback):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@LOGGER_REGISTRY
|
|
||||||
class CustomLogger(Logger):
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert "SGD" in OPTIMIZER_REGISTRY.names
|
|
||||||
assert "RMSprop" in OPTIMIZER_REGISTRY.names
|
|
||||||
assert "CustomAdam" in OPTIMIZER_REGISTRY.names
|
|
||||||
|
|
||||||
assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
|
|
||||||
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
|
|
||||||
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
|
|
||||||
assert "ReduceLROnPlateau" in LR_SCHEDULER_REGISTRY.names
|
|
||||||
|
|
||||||
assert "EarlyStopping" in CALLBACK_REGISTRY.names
|
|
||||||
assert "CustomCallback" in CALLBACK_REGISTRY.names
|
|
||||||
|
|
||||||
class Foo:
|
|
||||||
...
|
|
||||||
|
|
||||||
OPTIMIZER_REGISTRY(Foo, key="SGD") # not overridden by default
|
|
||||||
assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD
|
|
||||||
OPTIMIZER_REGISTRY(Foo, key="SGD", override=True)
|
|
||||||
assert OPTIMIZER_REGISTRY["SGD"] is Foo
|
|
||||||
|
|
||||||
# test `_Registry.__call__` returns the class
|
|
||||||
assert isinstance(CustomCallback(), CustomCallback)
|
|
||||||
|
|
||||||
assert "WandbLogger" in LOGGER_REGISTRY
|
|
||||||
assert "CustomLogger" in LOGGER_REGISTRY
|
|
||||||
|
|
||||||
|
|
||||||
def test_registries_register_automatically():
|
|
||||||
assert "SaveConfigCallback" not in CALLBACK_REGISTRY
|
|
||||||
with mock.patch("sys.argv", ["any.py"]):
|
|
||||||
LightningCLI(BoringModel, run=False, auto_registry=True)
|
|
||||||
assert "SaveConfigCallback" in CALLBACK_REGISTRY
|
|
||||||
|
|
||||||
|
|
||||||
class TestModel(BoringModel):
|
class TestModel(BoringModel):
|
||||||
def __init__(self, foo, bar=5):
|
def __init__(self, foo, bar=5):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -1137,11 +1064,11 @@ def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload
|
||||||
super().__init__(*args, run=False)
|
super().__init__(*args, run=False)
|
||||||
|
|
||||||
def add_arguments_to_parser(self, parser):
|
def add_arguments_to_parser(self, parser):
|
||||||
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
|
parser.add_optimizer_args(nested_key="opt1", link_to="model.opt1_config")
|
||||||
parser.add_optimizer_args(
|
parser.add_optimizer_args(
|
||||||
(torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
|
(torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
|
||||||
)
|
)
|
||||||
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
|
parser.add_lr_scheduler_args(link_to="model.sch_config")
|
||||||
parser.add_argument("--something", type=str, nargs="+")
|
parser.add_argument("--something", type=str, nargs="+")
|
||||||
|
|
||||||
class TestModel(BoringModel):
|
class TestModel(BoringModel):
|
||||||
|
|
Loading…
Reference in New Issue