diff --git a/CHANGELOG.md b/CHANGELOG.md index d2590a70e0..39f0918e03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -130,6 +130,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `Trainer.reset_train_val_dataloaders()` in favor of `Trainer.reset_{train,val}_dataloader` ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184)) + +- Deprecated LightningCLI's registries in favor of importing the respective package ([#13221](https://github.com/PyTorchLightning/pytorch-lightning/pull/13221)) + + ### Removed - Removed the deprecated `Logger.close` method ([#13149](https://github.com/PyTorchLightning/pytorch-lightning/pull/13149)) diff --git a/docs/source-pytorch/cli/lightning_cli.rst b/docs/source-pytorch/cli/lightning_cli.rst index 76f3f12140..a0933b447a 100644 --- a/docs/source-pytorch/cli/lightning_cli.rst +++ b/docs/source-pytorch/cli/lightning_cli.rst @@ -27,7 +27,7 @@ Basic use .. displayitem:: :header: 2: Mix models and datasets - :description: Register models, datasets, optimizers and learning rate schedulers + :description: Support multiple models, datasets, optimizers and learning rate schedulers :col_css: col-md-4 :button_link: lightning_cli_intermediate_2.html :height: 150 @@ -66,8 +66,8 @@ Advanced use :tag: advanced .. displayitem:: - :header: Customize configs for complex projects - :description: Learn how to connect complex projects with each Registry. + :header: Customize for complex projects + :description: Learn how to implement CLIs for complex projects. :col_css: col-md-6 :button_link: lightning_cli_advanced_3.html :height: 150 diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index 6ecc43bed7..0e9c3f406d 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -63,29 +63,6 @@ This can be useful to implement custom logic without having to subclass the CLI, and argument parsing capabilities. -Subclass registration -^^^^^^^^^^^^^^^^^^^^^ - -To use shorthand notation, the options need to be registered beforehand. This can be easily done with: - -.. code-block:: - - LightningCLI(auto_registry=True) # False by default - -which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`, -:class:`~pytorch_lightning.core.module.LightningModule`, -:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and -:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules. This includes those in your own -code. - -Alternatively, if this is left unset, only the subclasses defined in PyTorch's :class:`torch.optim.Optimizer`, -:class:`torch.optim.lr_scheduler._LRScheduler` and Lightning's :class:`~pytorch_lightning.callbacks.Callback` and -:class:`~pytorch_lightning.loggers.LightningLoggerBase` subclassess will be registered. - -In subsequent sections, we will go over adding specific classes to specific registries as well as how to use -shorthand notation. - - Trainer Callbacks and arguments with class type ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -107,14 +84,14 @@ file example that defines a couple of callbacks is the following: init_args: ... -Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended +Similar to the callbacks, any parameter in :class:`~pytorch_lightning.trainer.trainer.Trainer` and user extended :class:`~pytorch_lightning.core.module.LightningModule` and -:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured -the same way using :code:`class_path` and :code:`init_args`. +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class, can be +configured the same way using :code:`class_path` and :code:`init_args`. If the package that defines a subclass is +imported before the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class is run, the name can be used instead of +the full import path. -For callbacks in particular, Lightning simplifies the command line so that only -the :class:`~pytorch_lightning.callbacks.Callback` name is required. -The argument's order matters and the user needs to pass the arguments in the following way. +From command line the syntax is the following: .. code-block:: bash @@ -127,7 +104,8 @@ The argument's order matters and the user needs to pass the arguments in the fol --trainer.callbacks.{CALLBACK_N_ARGS_1}=... \ ... -Here is an example: +Note the use of ``+`` to append a new callback to the list and that the ``init_args`` are applied to the previous +callback appended. Here is an example: .. code-block:: bash @@ -137,43 +115,11 @@ Here is an example: --trainer.callbacks+=LearningRateMonitor \ --trainer.callbacks.logging_interval=epoch -Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification -as described above: - -.. code-block:: python - - from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY - - - @CALLBACK_REGISTRY - class CustomCallback(Callback): - ... - - - cli = LightningCLI(...) - -.. code-block:: bash - - $ python ... --trainer.callbacks+=CustomCallback ... - .. note:: - This shorthand notation is also supported inside a configuration file. The configuration file - generated by calling the previous command with ``--print_config`` will have the full ``class_path`` notation. - - .. code-block:: yaml - - trainer: - callbacks: - - class_path: your_class_path.CustomCallback - init_args: - ... - - -.. tip:: - - ``--trainer.logger`` also supports shorthand notation and a ``LOGGER_REGISTRY`` is available to register custom - Loggers. + Serialized config files (e.g. ``--print_config`` or :class:`~pytorch_lightning.utilities.cli.SaveConfigCallback`) + always have the full ``class_path``'s, even when class name shorthand notation is used in command line or in input + config files. Multiple models and/or datasets @@ -377,12 +323,8 @@ example can be when one wants to add support for multiple optimizers: class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): - parser.add_optimizer_args( - OPTIMIZER_REGISTRY.classes, nested_key="gen_optimizer", link_to="model.optimizer1_init" - ) - parser.add_optimizer_args( - OPTIMIZER_REGISTRY.classes, nested_key="gen_discriminator", link_to="model.optimizer2_init" - ) + parser.add_optimizer_args(nested_key="optimizer1", link_to="model.optimizer1_init") + parser.add_optimizer_args(nested_key="optimizer2", link_to="model.optimizer2_init") cli = MyLightningCLI(MyModel) @@ -398,18 +340,17 @@ With shorthand notation: .. code-block:: bash $ python trainer.py fit \ - --gen_optimizer=Adam \ - --gen_optimizer.lr=0.01 \ - --gen_discriminator=AdamW \ - --gen_discriminator.lr=0.0001 + --optimizer1=Adam \ + --optimizer1.lr=0.01 \ + --optimizer2=AdamW \ + --optimizer2.lr=0.0001 -You can also pass the class path directly, for example, if the optimizer hasn't been registered to the -``OPTIMIZER_REGISTRY``: +You can also pass the class path directly, for example, if the optimizer hasn't been imported: .. code-block:: bash $ python trainer.py fit \ - --gen_optimizer.class_path=torch.optim.Adam \ - --gen_optimizer.init_args.lr=0.01 \ - --gen_discriminator.class_path=torch.optim.AdamW \ - --gen_discriminator.init_args.lr=0.0001 + --optimizer1=torch.optim.Adam \ + --optimizer1.lr=0.01 \ + --optimizer2=torch.optim.AdamW \ + --optimizer2.lr=0.0001 diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst index 8d0a5e0910..8e312b7233 100644 --- a/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst +++ b/docs/source-pytorch/cli/lightning_cli_intermediate_2.rst @@ -46,9 +46,9 @@ This is what the Lightning CLI enables. Otherwise, this kind of configuration re ---- ************************* -Register LightningModules +Multiple LightningModules ************************* -Connect models across different files with the ``MODEL_REGISTRY`` to make them available from the CLI: +To support multiple models, when instantiating ``LightningCLI`` omit the ``model_class`` parameter: .. code:: python @@ -58,14 +58,12 @@ Connect models across different files with the ``MODEL_REGISTRY`` to make them a from pytorch_lightning.utilities import cli as pl_cli - @pl_cli.MODEL_REGISTRY class Model1(DemoModel): def configure_optimizers(self): print("⚡", "using Model1", "⚡") return super().configure_optimizers() - @pl_cli.MODEL_REGISTRY class Model2(DemoModel): def configure_optimizers(self): print("⚡", "using Model2", "⚡") @@ -87,9 +85,9 @@ Now you can choose between any model from the CLI: ---- ******************** -Register DataModules +Multiple DataModules ******************** -Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to make them available from the CLI: +To support multiple data modules, when instantiating ``LightningCLI`` omit the ``datamodule_class`` parameter: .. code:: python @@ -99,14 +97,12 @@ Connect DataModules across different files with the ``DATAMODULE_REGISTRY`` to m from pytorch_lightning import demos - @pl_cli.DATAMODULE_REGISTRY class FakeDataset1(BoringDataModule): def train_dataloader(self): print("⚡", "using FakeDataset1", "⚡") return torch.utils.data.DataLoader(self.random_train) - @pl_cli.DATAMODULE_REGISTRY class FakeDataset2(BoringDataModule): def train_dataloader(self): print("⚡", "using FakeDataset2", "⚡") @@ -127,10 +123,10 @@ Now you can choose between any dataset at runtime: ---- -******************* -Register optimizers -******************* -Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from the CLI: +***************** +Custom optimizers +***************** +Any subclass of ``torch.optim.Optimizer`` can be used as an optimizer: .. code:: python @@ -140,14 +136,12 @@ Connect optimizers with the ``OPTIMIZER_REGISTRY`` to make them available from t from pytorch_lightning import demos - @pl_cli.OPTIMIZER_REGISTRY class LitAdam(torch.optim.Adam): def step(self, closure): print("⚡", "using LitAdam", "⚡") super().step(closure) - @pl_cli.OPTIMIZER_REGISTRY class FancyAdam(torch.optim.Adam): def step(self, closure): print("⚡", "using FancyAdam", "⚡") @@ -166,7 +160,8 @@ Now you can choose between any optimizer at runtime: # use FancyAdam python main.py fit --optimizer FancyAdam -Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from ``torch.optim.optim``: +Bonus: If you need only 1 optimizer, the Lightning CLI already works out of the box with any Optimizer from +``torch.optim``: .. code:: bash @@ -180,10 +175,10 @@ If the optimizer you want needs other arguments, add them via the CLI (no need t ---- -********************** -Register LR schedulers -********************** -Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them available from the CLI: +******************** +Custom LR schedulers +******************** +Any subclass of ``torch.optim.lr_scheduler._LRScheduler`` can be used as learning rate scheduler: .. code:: python @@ -193,7 +188,6 @@ Connect learning rate schedulers with the ``LR_SCHEDULER_REGISTRY`` to make them from pytorch_lightning import demos - @pl_cli.LR_SCHEDULER_REGISTRY class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR): def step(self): print("⚡", "using LitLRScheduler", "⚡") @@ -210,7 +204,8 @@ Now you can choose between any learning rate scheduler at runtime: python main.py fit --lr_scheduler LitLRScheduler -Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from ``torch.optim``: +Bonus: If you need only 1 LRScheduler, the Lightning CLI already works out of the box with any LRScheduler from +``torch.optim``: .. code:: bash @@ -226,26 +221,31 @@ If the scheduler you want needs other arguments, add them via the CLI (no need t ---- -************************* -Register from any package -************************* -A shortcut to register many classes from a package is to use the ``register_classes`` method. Here we register all optimizers from the ``torch.optim`` library: +************************ +Classes from any package +************************ +In the previous sections the classes to select were defined in the same python file where the ``LightningCLI`` class is +run. To select classes from any package by using only the class name, import the respective package: .. code:: python import torch from pytorch_lightning.utilities import cli as pl_cli - from pytorch_lightning import demos + import my_code.models # noqa: F401 + import my_code.data_modules # noqa: F401 + import my_code.optimizers # noqa: F401 - # add all PyTorch optimizers! - pl_cli.OPTIMIZER_REGISTRY.register_classes(module=torch.optim, base_cls=torch.optim.Optimizer) + cli = pl_cli.LightningCLI() - cli = pl_cli.LightningCLI(DemoModel, BoringDataModule) - -Now use any of the optimizers in the ``torch.optim`` library: +Now use any of the classes: .. code:: bash - python main.py fit --optimizer AdamW + python main.py fit --model Model1 --data FakeDataset1 --optimizer LitAdam --lr_scheduler LitLRScheduler -This method is supported by all the registry classes. +The ``# noqa: F401`` comment avoids a linter warning that the import is unused. It is also possible to select subclasses +that have not been imported by giving the full import path: + +.. code:: bash + + python main.py fit --model my_code.models.Model1 diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py index 6a2b32432e..d9386c4dd7 100644 --- a/src/pytorch_lightning/utilities/cli.py +++ b/src/pytorch_lightning/utilities/cli.py @@ -51,8 +51,23 @@ else: locals()["Namespace"] = object -class _Registry(dict): - def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> Type: +_deprecate_registry_message = ( + "`LightningCLI`'s registries were deprecated in v1.7 and will be removed " + "in v1.9. Now any imported subclass is automatically available by name in " + "`LightningCLI` without any need to explicitly register it." +) + +_deprecate_auto_registry_message = ( + "`LightningCLI.auto_registry` parameter was deprecated in v1.7 and will be removed " + "in v1.9. Now any imported subclass is automatically available by name in " + "`LightningCLI` without any need to explicitly register it." +) + + +class _Registry(dict): # Remove in v1.9 + def __call__( + self, cls: Type, key: Optional[str] = None, override: bool = False, show_deprecation: bool = True + ) -> Type: """Registers a class mapped to a name. Args: @@ -67,12 +82,16 @@ class _Registry(dict): if key not in self or override: self[key] = cls + + self._deprecation(show_deprecation) return cls - def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None: + def register_classes( + self, module: ModuleType, base_cls: Type, override: bool = False, show_deprecation: bool = True + ) -> None: """This function is an utility to register all classes from a module.""" for cls in self.get_members(module, base_cls): - self(cls=cls, override=override) + self(cls=cls, override=override, show_deprecation=show_deprecation) @staticmethod def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]: @@ -85,16 +104,23 @@ class _Registry(dict): @property def names(self) -> List[str]: """Returns the registered names.""" + self._deprecation() return list(self.keys()) @property def classes(self) -> Tuple[Type, ...]: """Returns the registered classes.""" + self._deprecation() return tuple(self.values()) def __str__(self) -> str: return f"Registered objects: {self.names}" + def _deprecation(self, show_deprecation: bool = True) -> None: + if show_deprecation and not getattr(self, "deprecation_shown", False): + rank_zero_deprecation(_deprecate_registry_message) + self.deprecation_shown = True + OPTIMIZER_REGISTRY = _Registry() LR_SCHEDULER_REGISTRY = _Registry() @@ -116,29 +142,32 @@ LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPl LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]] -def _populate_registries(subclasses: bool) -> None: +def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 if subclasses: + rank_zero_deprecation(_deprecate_auto_registry_message) # this will register any subclasses from all loaded modules including userland for cls in get_all_subclasses(torch.optim.Optimizer): - OPTIMIZER_REGISTRY(cls) + OPTIMIZER_REGISTRY(cls, show_deprecation=False) for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler): - LR_SCHEDULER_REGISTRY(cls) + LR_SCHEDULER_REGISTRY(cls, show_deprecation=False) for cls in get_all_subclasses(pl.Callback): - CALLBACK_REGISTRY(cls) + CALLBACK_REGISTRY(cls, show_deprecation=False) for cls in get_all_subclasses(pl.LightningModule): - MODEL_REGISTRY(cls) + MODEL_REGISTRY(cls, show_deprecation=False) for cls in get_all_subclasses(pl.LightningDataModule): - DATAMODULE_REGISTRY(cls) + DATAMODULE_REGISTRY(cls, show_deprecation=False) for cls in get_all_subclasses(pl.loggers.Logger): - LOGGER_REGISTRY(cls) + LOGGER_REGISTRY(cls, show_deprecation=False) else: # manually register torch's subclasses and our subclasses - OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer) - LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) - CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback) - LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger) + OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer, show_deprecation=False) + LR_SCHEDULER_REGISTRY.register_classes( + torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler, show_deprecation=False + ) + CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback, show_deprecation=False) + LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.Logger, show_deprecation=False) # `ReduceLROnPlateau` does not subclass `_LRScheduler` - LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau) + LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau, show_deprecation=False) class LightningArgumentParser(ArgumentParser): @@ -211,14 +240,14 @@ class LightningArgumentParser(ArgumentParser): def add_optimizer_args( self, - optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]], + optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,), nested_key: str = "optimizer", link_to: str = "AUTOMATIC", ) -> None: """Adds arguments from an optimizer class to a nested key of the parser. Args: - optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. + optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. """ @@ -235,14 +264,15 @@ class LightningArgumentParser(ArgumentParser): def add_lr_scheduler_args( self, - lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]], + lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple, nested_key: str = "lr_scheduler", link_to: str = "AUTOMATIC", ) -> None: """Adds arguments from a learning rate scheduler class to a nested key of the parser. Args: - lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. + lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use + tuple to allow subclasses. nested_key: Name of the nested namespace to store arguments. link_to: Dot notation of a parser key to set arguments or AUTOMATIC. """ diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 18aed8c7e1..74d509bd8d 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -19,7 +19,14 @@ import pytest import pytorch_lightning.loggers.base as logger_base from pytorch_lightning import Trainer from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.utilities.cli import LightningCLI +from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.utilities.cli import ( + _deprecate_auto_registry_message, + _deprecate_registry_message, + CALLBACK_REGISTRY, + LightningCLI, + SaveConfigCallback, +) from pytorch_lightning.utilities.rank_zero import rank_zero_only @@ -133,3 +140,17 @@ def test_deprecated_dataloader_reset(): trainer = Trainer() with pytest.deprecated_call(match="reset_train_val_dataloaders` has been deprecated in v1.7"): trainer.reset_train_val_dataloaders() + + +def test_lightningCLI_registries_register(): + with pytest.deprecated_call(match=_deprecate_registry_message): + + @CALLBACK_REGISTRY + class CustomCallback(SaveConfigCallback): + pass + + +def test_lightningCLI_registries_register_automatically(): + with pytest.deprecated_call(match=_deprecate_auto_registry_message): + with mock.patch("sys.argv", ["any.py"]): + LightningCLI(BoringModel, run=False, auto_registry=True) diff --git a/tests/tests_pytorch/utilities/test_cli.py b/tests/tests_pytorch/utilities/test_cli.py index 4049e09af8..1dfab76484 100644 --- a/tests/tests_pytorch/utilities/test_cli.py +++ b/tests/tests_pytorch/utilities/test_cli.py @@ -34,22 +34,15 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel -from pytorch_lightning.loggers import Logger, TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import ( - _populate_registries, - CALLBACK_REGISTRY, - DATAMODULE_REGISTRY, instantiate_class, LightningArgumentParser, LightningCLI, - LOGGER_REGISTRY, - LR_SCHEDULER_REGISTRY, LRSchedulerTypeTuple, - MODEL_REGISTRY, - OPTIMIZER_REGISTRY, SaveConfigCallback, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -914,72 +907,6 @@ def test_lightning_cli_run(): assert isinstance(cli.model, LightningModule) -@pytest.fixture(autouse=True) -def clear_registries(): - # since the registries are global, it's good to clear them after each test to avoid unwanted interactions - yield - OPTIMIZER_REGISTRY.clear() - LR_SCHEDULER_REGISTRY.clear() - CALLBACK_REGISTRY.clear() - MODEL_REGISTRY.clear() - DATAMODULE_REGISTRY.clear() - LOGGER_REGISTRY.clear() - - -def test_registries(): - # the registries are global so this is only necessary when this test is run standalone - _populate_registries(False) - - @OPTIMIZER_REGISTRY - class CustomAdam(torch.optim.Adam): - pass - - @LR_SCHEDULER_REGISTRY - class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): - pass - - @CALLBACK_REGISTRY - class CustomCallback(Callback): - pass - - @LOGGER_REGISTRY - class CustomLogger(Logger): - pass - - assert "SGD" in OPTIMIZER_REGISTRY.names - assert "RMSprop" in OPTIMIZER_REGISTRY.names - assert "CustomAdam" in OPTIMIZER_REGISTRY.names - - assert "CosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names - assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names - assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names - assert "ReduceLROnPlateau" in LR_SCHEDULER_REGISTRY.names - - assert "EarlyStopping" in CALLBACK_REGISTRY.names - assert "CustomCallback" in CALLBACK_REGISTRY.names - - class Foo: - ... - - OPTIMIZER_REGISTRY(Foo, key="SGD") # not overridden by default - assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD - OPTIMIZER_REGISTRY(Foo, key="SGD", override=True) - assert OPTIMIZER_REGISTRY["SGD"] is Foo - - # test `_Registry.__call__` returns the class - assert isinstance(CustomCallback(), CustomCallback) - - assert "WandbLogger" in LOGGER_REGISTRY - assert "CustomLogger" in LOGGER_REGISTRY - - -def test_registries_register_automatically(): - assert "SaveConfigCallback" not in CALLBACK_REGISTRY - with mock.patch("sys.argv", ["any.py"]): - LightningCLI(BoringModel, run=False, auto_registry=True) - assert "SaveConfigCallback" in CALLBACK_REGISTRY - - class TestModel(BoringModel): def __init__(self, foo, bar=5): super().__init__() @@ -1137,11 +1064,11 @@ def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload super().__init__(*args, run=False) def add_arguments_to_parser(self, parser): - parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes, nested_key="opt1", link_to="model.opt1_config") + parser.add_optimizer_args(nested_key="opt1", link_to="model.opt1_config") parser.add_optimizer_args( (torch.optim.ASGD, torch.optim.SGD), nested_key="opt2", link_to="model.opt2_config" ) - parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes, link_to="model.sch_config") + parser.add_lr_scheduler_args(link_to="model.sch_config") parser.add_argument("--something", type=str, nargs="+") class TestModel(BoringModel):