Deprecate CLI registries and update documentation (#13221)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Authored by Mauricio Villegas on 2022-06-21 10:12:04 -05:00, committed by GitHub
parent ad87d2cad0
commit 0ae9627bf8
7 changed files with 137 additions and 214 deletions


@@ -130,6 +130,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `Trainer.reset_train_val_dataloaders()` in favor of `Trainer.reset_{train,val}_dataloader` ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
+- Deprecated LightningCLI's registries in favor of importing the respective package ([#13221](https://github.com/PyTorchLightning/pytorch-lightning/pull/13221))
 ### Removed
 - Removed the deprecated `Logger.close` method ([#13149](https://github.com/PyTorchLightning/pytorch-lightning/pull/13149))
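In practice, the deprecation replaces decorator-based registration (e.g. ``@CALLBACK_REGISTRY``, shown further down in this diff) with plain imports. A minimal sketch of the new pattern, where ``my_project.callbacks`` is a hypothetical module defining a ``Callback`` subclass:

.. code-block:: python

    from pytorch_lightning.utilities.cli import LightningCLI

    import my_project.callbacks  # noqa: F401  # hypothetical; defining the subclass and importing it is enough

    # any imported Callback subclass can now be selected by name, e.g.
    #   python main.py fit --trainer.callbacks+=MyCallback
    cli = LightningCLI()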


@@ -27,7 +27,7 @@ Basic use
 .. displayitem::
    :header: 2: Mix models and datasets
-   :description: Register models, datasets, optimizers and learning rate schedulers
+   :description: Support multiple models, datasets, optimizers and learning rate schedulers
    :col_css: col-md-4
    :button_link: lightning_cli_intermediate_2.html
    :height: 150
@@ -66,8 +66,8 @@ Advanced use
    :tag: advanced
 .. displayitem::
-   :header: Customize configs for complex projects
-   :description: Learn how to connect complex projects with each Registry.
+   :header: Customize for complex projects
+   :description: Learn how to implement CLIs for complex projects.
    :col_css: col-md-6
    :button_link: lightning_cli_advanced_3.html
    :height: 150


@@ -63,29 +63,6 @@ This can be useful to implement custom logic without having to subclass the CLI,
 and argument parsing capabilities.

-Subclass registration
-^^^^^^^^^^^^^^^^^^^^^
-
-To use shorthand notation, the options need to be registered beforehand. This can be easily done with:
-
-.. code-block::
-
-    LightningCLI(auto_registry=True)  # False by default
-
-which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`,
-:class:`~pytorch_lightning.core.module.LightningModule`,
-:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and
-:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules. This includes those in your own
-code.
-
-Alternatively, if this is left unset, only the subclasses defined in PyTorch's :class:`torch.optim.Optimizer`,
-:class:`torch.optim.lr_scheduler._LRScheduler` and Lightning's :class:`~pytorch_lightning.callbacks.Callback` and
-:class:`~pytorch_lightning.loggers.LightningLoggerBase` will be registered.
-
-In subsequent sections, we will go over adding specific classes to specific registries as well as how to use
-shorthand notation.
-
 Trainer Callbacks and arguments with class type
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -107,14 +84,14 @@ file example that defines a couple of callbacks is the following:
       init_args:
         ...

-Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended
+Similar to the callbacks, any parameter in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended
 :class:`~pytorch_lightning.core.module.LightningModule` and
-:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured
-the same way using :code:`class_path` and :code:`init_args`.
+:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class, can be
+configured the same way using :code:`class_path` and :code:`init_args`. If the package that defines a subclass is
+imported before the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class is run, the name can be used instead of
+the full import path.

-For callbacks in particular, Lightning simplifies the command line so that only
-the :class:`~pytorch_lightning.callbacks.Callback` name is required.
-The argument's order matters and the user needs to pass the arguments in the following way.
+From command line the syntax is the following:

 .. code-block:: bash
@@ -127,7 +104,8 @@ The argument's order matters and the user needs to pass the arguments in the following way.
         --trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
         ...

-Here is an example:
+Note the use of ``+`` to append a new callback to the list and that the ``init_args`` are applied to the previous
+callback appended. Here is an example:

 .. code-block:: bash
@@ -137,43 +115,11 @@ Here is an example:
         --trainer.callbacks+=LearningRateMonitor \
         --trainer.callbacks.logging_interval=epoch

-Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
-as described above:
-
-.. code-block:: python
-
-    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
-
-
-    @CALLBACK_REGISTRY
-    class CustomCallback(Callback):
-        ...
-
-
-    cli = LightningCLI(...)
-
-.. code-block:: bash
-
-    $ python ... --trainer.callbacks+=CustomCallback ...
-
 .. note::

-    This shorthand notation is also supported inside a configuration file. The configuration file
-    generated by calling the previous command with ``--print_config`` will have the full ``class_path`` notation.
-
-    .. code-block:: yaml
-
-        trainer:
-          callbacks:
-            - class_path: your_class_path.CustomCallback
-              init_args:
-                ...
-
-.. tip::
-
-    ``--trainer.logger`` also supports shorthand notation and a ``LOGGER_REGISTRY`` is available to register custom
-    Loggers.
+    Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback`)
+    always have the full ``class_path``'s, even when class name shorthand notation is used in command line or in input
+    config files.
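As a rough illustration of that note (the callback and values below are only an example): running with ``--trainer.callbacks+=EarlyStopping --trainer.callbacks.monitor=val_loss`` and ``--print_config`` would serialize the callback with its full path, along these lines:

.. code-block:: yaml

    trainer:
      callbacks:
        - class_path: pytorch_lightning.callbacks.EarlyStopping
          init_args:
            monitor: val_loss
            patience: 3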
 Multiple models and/or datasets
@@ -377,12 +323,8 @@ example can be when one wants to add support for multiple optimizers:
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args(
-                OPTIMIZER_REGISTRY.classes, nested_key="gen_optimizer", link_to="model.optimizer1_init"
-            )
-            parser.add_optimizer_args(
-                OPTIMIZER_REGISTRY.classes, nested_key="gen_discriminator", link_to="model.optimizer2_init"
-            )
+            parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init")
+            parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init")


     cli = MyLightningCLI(MyModel)
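The ``link_to`` targets above assume the model accepts the collected init dictionaries. A minimal sketch of such a model (not shown in this diff), using the ``instantiate_class`` helper exported by ``pytorch_lightning.utilities.cli``:

.. code-block:: python

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import instantiate_class


    class MyModel(LightningModule):
        def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
            super().__init__()
            self.optimizer1_init = optimizer1_init
            self.optimizer2_init = optimizer2_init

        def configure_optimizers(self):
            # each *_init dict holds the chosen class_path and init_args collected by the parser
            opt1 = instantiate_class(self.parameters(), self.optimizer1_init)
            opt2 = instantiate_class(self.parameters(), self.optimizer2_init)
            return opt1, opt2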
@@ -398,18 +340,17 @@ With shorthand notation:
 .. code-block:: bash

     $ python trainer.py fit \
-        --gen_optimizer=Adam \
-        --gen_optimizer.lr=0.01 \
-        --gen_discriminator=AdamW \
-        --gen_discriminator.lr=0.0001
+        --optimizer1=Adam \
+        --optimizer1.lr=0.01 \
+        --optimizer2=AdamW \
+        --optimizer2.lr=0.0001

-You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
-``OPTIMIZER_REGISTRY``:
+You can also pass the class path directly, for example, if the optimizer hasn't been imported:

 .. code-block:: bash

     $ python trainer.py fit \
-        --gen_optimizer.class_path=torch.optim.Adam \
-        --gen_optimizer.init_args.lr=0.01 \
-        --gen_discriminator.class_path=torch.optim.AdamW \
-        --gen_discriminator.init_args.lr=0.0001
+        --optimizer1=torch.optim.Adam \
+        --optimizer1.lr=0.01 \
+        --optimizer2=torch.optim.AdamW \
+        --optimizer2.lr=0.0001


@@ -46,9 +46,9 @@ This is what the Lightning CLI enables. Otherwise, this kind of configuration re
 ----

 *************************
-Register LightningModules
+Multiple LightningModules
 *************************

-Connect models across different files with the ``MODEL_REGISTRY`` to make them available from the CLI:
+To support multiple models, when instantiating ``LightningCLI`` omit the ``model_class`` parameter:

 .. code:: python
@@ -58,14 +58,12 @@ Connect models across different files with the ``MODEL_REGISTRY`` to make them available from the CLI:
     from pytorch_lightning.utilities import cli as pl_cli


-    @pl_cli.MODEL_REGISTRY
     class Model1(DemoModel):
         def configure_optimizers(self):
             print("⚡", "using Model1", "⚡")
             return super().configure_optimizers()


-    @pl_cli.MODEL_REGISTRY
     class Model2(DemoModel):
         def configure_optimizers(self):
             print("⚡", "using Model2", "⚡")
@@ -87,9 +85,9 @@ Now you can choose between any model from the CLI:
 ----

 ********************
-Register DataModules
+Multiple DataModules
 ********************

-Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to make them available from the CLI:
+To support multiple data modules, when instantiating ``LightningCLI`` omit the ``datamodule_class`` parameter:

 .. code:: python
@@ -99,14 +97,12 @@ Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to make them available from the CLI:
     from pytorch_lightning import demos


-    @pl_cli.DATAMODULE_REGISTRY
     class FakeDataset1(BoringDataModule):
         def train_dataloader(self):
             print("⚡", "using FakeDataset1", "⚡")
             return torch.utils.data.DataLoader(self.random_train)


-    @pl_cli.DATAMODULE_REGISTRY
     class FakeDataset2(BoringDataModule):
         def train_dataloader(self):
             print("⚡", "using FakeDataset2", "⚡")
@@ -127,10 +123,10 @@ Now you can choose between any dataset at runtime:
 ----

-*******************
-Register optimizers
-*******************
+*****************
+Custom optimizers
+*****************

-Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from the CLI:
+Any subclass of ``torch.optim.Optimizer`` can be used as an optimizer:

 .. code:: python
@@ -140,14 +136,12 @@ Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from the CLI:
     from pytorch_lightning import demos


-    @pl_cli.OPTIMIZER_REGISTRY
     class LitAdam(torch.optim.Adam):
         def step(self, closure):
             print("⚡", "using LitAdam", "⚡")
             super().step(closure)


-    @pl_cli.OPTIMIZER_REGISTRY
     class FancyAdam(torch.optim.Adam):
         def step(self, closure):
             print("⚡", "using FancyAdam", "⚡")
@@ -166,7 +160,8 @@ Now you can choose between any optimizer at runtime:
     # use FancyAdam
     python main.py fit --optimizer FancyAdam

-Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from ``torch.optim.optim``:
+Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from
+``torch.optim``:

 .. code:: bash
@@ -180,10 +175,10 @@ If the optimizer you want needs other arguments, add them via the CLI (no need t
 ----

-**********************
-Register LR schedulers
-**********************
+********************
+Custom LR schedulers
+********************

-Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them available from the CLI:
+Any subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler:

 .. code:: python
@@ -193,7 +188,6 @@ Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them available from the CLI:
     from pytorch_lightning import demos


-    @pl_cli.LR_SCHEDULER_REGISTRY
     class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
         def step(self):
             print("⚡", "using LitLRScheduler", "⚡")
@@ -210,7 +204,8 @@ Now you can choose between any learning rate scheduler at runtime:
     python main.py fit --lr_scheduler LitLRScheduler

-Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from ``torch.optim``:
+Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from
+``torch.optim``:

 .. code:: bash
@@ -226,26 +221,31 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t
 ----

-*************************
-Register from any package
-*************************
+************************
+Classes from any package
+************************

-A shortcut to register many classes from a package is to use the ``register_classes`` method. Here we register all optimizers from the ``torch.optim`` library:
+In the previous sections the classes to select were defined in the same python file where the ``LightningCLI`` class is
+run. To select classes from any package by using only the class name, import the respective package:

 .. code:: python

     import torch
     from pytorch_lightning.utilities import cli as pl_cli
-    from pytorch_lightning import demos
+    import my_code.models  # noqa: F401
+    import my_code.data_modules  # noqa: F401
+    import my_code.optimizers  # noqa: F401

-    # add all PyTorch optimizers!
-    pl_cli.OPTIMIZER_REGISTRY.register_classes(module=torch.optim, base_cls=torch.optim.Optimizer)

-    cli = pl_cli.LightningCLI(DemoModel, BoringDataModule)
+    cli = pl_cli.LightningCLI()

-Now use any of the optimizers in the ``torch.optim`` library:
+Now use any of the classes:

 .. code:: bash

-    python main.py fit --optimizer AdamW
+    python main.py fit --model Model1 --data FakeDataset1 --optimizer LitAdam --lr_scheduler LitLRScheduler

-This method is supported by all the registry classes.
+The ``# noqa: F401`` comment avoids a linter warning that the import is unused. It is also possible to select subclasses
+that have not been imported by giving the full import path:
+
+.. code:: bash
+
+    python main.py fit --model my_code.models.Model1
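For illustration only, a hypothetical ``my_code/models.py`` needs nothing beyond ordinary subclasses; importing the module is what makes them selectable by name:

.. code-block:: python

    # my_code/models.py (hypothetical)
    from pytorch_lightning import LightningModule


    class Model1(LightningModule):
        """Selectable as ``--model Model1`` once this module has been imported."""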


@@ -51,8 +51,23 @@ else:
     locals()["Namespace"] = object


-class _Registry(dict):
-    def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> Type:
+_deprecate_registry_message = (
+    "`LightningCLI`'s registries were deprecated in v1.7 and will be removed "
+    "in v1.9. Now any imported subclass is automatically available by name in "
+    "`LightningCLI` without any need to explicitly register it."
+)
+
+_deprecate_auto_registry_message = (
+    "`LightningCLI.auto_registry` parameter was deprecated in v1.7 and will be removed "
+    "in v1.9. Now any imported subclass is automatically available by name in "
+    "`LightningCLI` without any need to explicitly register it."
+)
+
+
+class _Registry(dict):  # Remove in v1.9
+    def __call__(
+        self, cls: Type, key: Optional[str] = None, override: bool = False, show_deprecation: bool = True
+    ) -> Type:
         """Registers a class mapped to a name.

         Args:
@@ -67,12 +82,16 @@ class _Registry(dict):
         if key not in self or override:
             self[key] = cls
+        self._deprecation(show_deprecation)
         return cls

-    def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
+    def register_classes(
+        self, module: ModuleType, base_cls: Type, override: bool = False, show_deprecation: bool = True
+    ) -> None:
         """This function is an utility to register all classes from a module."""
         for cls in self.get_members(module, base_cls):
-            self(cls=cls, override=override)
+            self(cls=cls, override=override, show_deprecation=show_deprecation)

     @staticmethod
     def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]:
@@ -85,16 +104,23 @@ class _Registry(dict):
     @property
     def names(self) -> List[str]:
         """Returns the registered names."""
+        self._deprecation()
         return list(self.keys())

     @property
     def classes(self) -> Tuple[Type, ...]:
         """Returns the registered classes."""
+        self._deprecation()
         return tuple(self.values())

     def __str__(self) -> str:
         return f"Registered objects: {self.names}"

+    def _deprecation(self, show_deprecation: bool = True) -> None:
+        if show_deprecation and not getattr(self, "deprecation_shown", False):
+            rank_zero_deprecation(_deprecate_registry_message)
+            self.deprecation_shown = True
+

 OPTIMIZER_REGISTRY = _Registry()
 LR_SCHEDULER_REGISTRY = _Registry()
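With this change, registering a class or reading ``names``/``classes`` warns once per registry. A rough sketch of what user code now sees:

.. code-block:: python

    from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY

    # emits the `_deprecate_registry_message` deprecation warning defined above
    print(OPTIMIZER_REGISTRY.names)
    # the `deprecation_shown` guard prevents the same registry from warning again
    print(OPTIMIZER_REGISTRY.classes)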
@@ -116,29 +142,32 @@ LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
 LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]


-def _populate_registries(subclasses: bool) -> None:
+def _populate_registries(subclasses: bool) -> None:  # Remove in v1.9
     if subclasses:
+        rank_zero_deprecation(_deprecate_auto_registry_message)
         # this will register any subclasses from all loaded modules including userland
         for cls in get_all_subclasses(torch.optim.Optimizer):
-            OPTIMIZER_REGISTRY(cls)
+            OPTIMIZER_REGISTRY(cls, show_deprecation=False)
         for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
-            LR_SCHEDULER_REGISTRY(cls)
+            LR_SCHEDULER_REGISTRY(cls, show_deprecation=False)
         for cls in get_all_subclasses(pl.Callback):
-            CALLBACK_REGISTRY(cls)
+            CALLBACK_REGISTRY(cls, show_deprecation=False)
         for cls in get_all_subclasses(pl.LightningModule):
-            MODEL_REGISTRY(cls)
+            MODEL_REGISTRY(cls, show_deprecation=False)
         for cls in get_all_subclasses(pl.LightningDataModule):
-            DATAMODULE_REGISTRY(cls)
+            DATAMODULE_REGISTRY(cls, show_deprecation=False)
         for cls in get_all_subclasses(pl.loggers.Logger):
-            LOGGER_REGISTRY(cls)
+            LOGGER_REGISTRY(cls, show_deprecation=False)
     else:
         # manually register torch's subclasses and our subclasses
-        OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
-        LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
-        CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback)
-        LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger)
+        OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer, show_deprecation=False)
+        LR_SCHEDULER_REGISTRY.register_classes(
+            torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler, show_deprecation=False
+        )
+        CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False)
+        LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False)
     # `ReduceLROnPlateau` does not subclass `_LRScheduler`
-    LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)
+    LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau, show_deprecation=False)


 class LightningArgumentParser(ArgumentParser):
@@ -211,14 +240,14 @@ class LightningArgumentParser(ArgumentParser):
     def add_optimizer_args(
         self,
-        optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]],
+        optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,),
         nested_key: str = "optimizer",
         link_to: str = "AUTOMATIC",
     ) -> None:
         """Adds arguments from an optimizer class to a nested key of the parser.

         Args:
-            optimizer_class: Any subclass of :class:`torch.optim.Optimizer`.
+            optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses.
             nested_key: Name of the nested namespace to store arguments.
             link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
         """
@@ -235,14 +264,15 @@ class LightningArgumentParser(ArgumentParser):
     def add_lr_scheduler_args(
         self,
-        lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]],
+        lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple,
         nested_key: str = "lr_scheduler",
         link_to: str = "AUTOMATIC",
     ) -> None:
         """Adds arguments from a learning rate scheduler class to a nested key of the parser.

         Args:
-            lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``.
+            lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use
+                tuple to allow subclasses.
             nested_key: Name of the nested namespace to store arguments.
             link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
         """


@@ -19,7 +19,14 @@ import pytest
 import pytorch_lightning.loggers.base as logger_base
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.module import LightningModule
-from pytorch_lightning.utilities.cli import LightningCLI
+from pytorch_lightning.demos.boring_classes import BoringModel
+from pytorch_lightning.utilities.cli import (
+    _deprecate_auto_registry_message,
+    _deprecate_registry_message,
+    CALLBACK_REGISTRY,
+    LightningCLI,
+    SaveConfigCallback,
+)
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
@@ -133,3 +140,17 @@ def test_deprecated_dataloader_reset():
     trainer = Trainer()
     with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"):
         trainer.reset_train_val_dataloaders()
+
+
+def test_lightningCLI_registries_register():
+    with pytest.deprecated_call(match=_deprecate_registry_message):
+
+        @CALLBACK_REGISTRY
+        class CustomCallback(SaveConfigCallback):
+            pass
+
+
+def test_lightningCLI_registries_register_automatically():
+    with pytest.deprecated_call(match=_deprecate_auto_registry_message):
+        with mock.patch("sys.argv", ["any.py"]):
+            LightningCLI(BoringModel, run=False, auto_registry=True)


@@ -34,22 +34,15 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
 from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
-from pytorch_lightning.loggers import Logger, TensorBoardLogger
+from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.cli import (
-    _populate_registries,
-    CALLBACK_REGISTRY,
-    DATAMODULE_REGISTRY,
     instantiate_class,
     LightningArgumentParser,
     LightningCLI,
-    LOGGER_REGISTRY,
-    LR_SCHEDULER_REGISTRY,
     LRSchedulerTypeTuple,
-    MODEL_REGISTRY,
-    OPTIMIZER_REGISTRY,
     SaveConfigCallback,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -914,72 +907,6 @@ def test_lightning_cli_run():
     assert isinstance(cli.model, LightningModule)


-@pytest.fixture(autouse=True)
-def clear_registries():
-    # since the registries are global, it's good to clear them after each test to avoid unwanted interactions
-    yield
-    OPTIMIZER_REGISTRY.clear()
-    LR_SCHEDULER_REGISTRY.clear()
-    CALLBACK_REGISTRY.clear()
-    MODEL_REGISTRY.clear()
-    DATAMODULE_REGISTRY.clear()
-    LOGGER_REGISTRY.clear()
-
-
-def test_registries():
-    # the registries are global so this is only necessary when this test is run standalone
-    _populate_registries(False)
-
-    @OPTIMIZER_REGISTRY
-    class CustomAdam(torch.optim.Adam):
-        pass
-
-    @LR_SCHEDULER_REGISTRY
-    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
-        pass
-
-    @CALLBACK_REGISTRY
-    class CustomCallback(Callback):
-        pass
-
-    @LOGGER_REGISTRY
-    class CustomLogger(Logger):
-        pass
-
-    assert "SGD" in OPTIMIZER_REGISTRY.names
-    assert "RMSprop" in OPTIMIZER_REGISTRY.names
-    assert "CustomAdam" in OPTIMIZER_REGISTRY.names
-
-    assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
-    assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
-    assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names
-    assert "ReduceLROnPlateau" in LR_SCHEDULER_REGISTRY.names
-
-    assert "EarlyStopping" in CALLBACK_REGISTRY.names
-    assert "CustomCallback" in CALLBACK_REGISTRY.names
-
-    class Foo:
-        ...
-
-    OPTIMIZER_REGISTRY(Foo, key="SGD")  # not overridden by default
-    assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD
-    OPTIMIZER_REGISTRY(Foo, key="SGD", override=True)
-    assert OPTIMIZER_REGISTRY["SGD"] is Foo
-
-    # test `_Registry.__call__` returns the class
-    assert isinstance(CustomCallback(), CustomCallback)
-
-    assert "WandbLogger" in LOGGER_REGISTRY
-    assert "CustomLogger" in LOGGER_REGISTRY
-
-
-def test_registries_register_automatically():
-    assert "SaveConfigCallback" not in CALLBACK_REGISTRY
-    with mock.patch("sys.argv", ["any.py"]):
-        LightningCLI(BoringModel, run=False, auto_registry=True)
-    assert "SaveConfigCallback" in CALLBACK_REGISTRY
-
-
 class TestModel(BoringModel):
     def __init__(self, foo, bar=5):
         super().__init__()
@@ -1137,11 +1064,11 @@ def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload
             super().__init__(*args, run=False)

         def add_arguments_to_parser(self, parser):
-            parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config")
+            parser.add_optimizer_args(nested_key="opt1", link_to="model.opt1_config")
             parser.add_optimizer_args(
                 (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config"
             )
-            parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config")
+            parser.add_lr_scheduler_args(link_to="model.sch_config")
             parser.add_argument("--something", type=str, nargs="+")

     class TestModel(BoringModel):