From 8fa156948a7ba6e3c8882ecaf3f0455e76119df2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 8 Mar 2022 18:26:10 +0100 Subject: [PATCH] Add `LightningCLI(auto_registry)` (#12108) --- CHANGELOG.md | 3 + docs/source/common/lightning_cli.rst | 23 ++++++++ pytorch_lightning/utilities/cli.py | 51 ++++++++++------ pytorch_lightning/utilities/meta.py | 2 +- tests/utilities/test_cli.py | 88 ++++++++++++++++++---------- 5 files changed, 119 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bbc65633bd..daf117eff1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/pull/10860)) +- Added `LightningCLI(auto_registry)` flag to register all subclasses of the registerable components automatically ([#12108](https://github.com/PyTorchLightning/pytorch-lightning/pull/12108)) + + - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/pull/10700)) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index ec77785d8a..7267606b05 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -345,6 +345,29 @@ This can be useful to implement custom logic without having to subclass the CLI, and argument parsing capabilities. +Subclass registration +^^^^^^^^^^^^^^^^^^^^^ + +To use shorthand notation, the options need to be registered beforehand. This can be easily done with: + +.. 
code-block:: + + LightningCLI(auto_registry=True) # False by default + +which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`, +:class:`~pytorch_lightning.core.lightning.LightningModule`, +:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and +:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules. This includes those in your own +code. + +Alternatively, if this is left unset, only the subclasses defined in PyTorch's :class:`torch.optim.Optimizer`, +:class:`torch.optim.lr_scheduler._LRScheduler` and Lightning's :class:`~pytorch_lightning.callbacks.Callback` and +:class:`~pytorch_lightning.loggers.LightningLoggerBase` subclasses will be registered. + +In subsequent sections, we will go over adding specific classes to specific registries as well as how to use +shorthand notation. + + Trainer Callbacks and arguments with class type ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 976f1d58b8..2e8d629a88 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -30,6 +30,7 @@ from pytorch_lightning import Callback, LightningDataModule, LightningModule, se from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE +from pytorch_lightning.utilities.meta import get_all_subclasses from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion @@ -58,9 +59,8 @@ class _Registry(dict): elif not isinstance(key, str): raise TypeError(f"`key` must be a str, found {key}") -
if key in self and not override: - raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.") - self[key] = cls + if key not in self or override: + self[key] = cls return cls def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None: @@ -91,10 +91,11 @@ class _Registry(dict): OPTIMIZER_REGISTRY = _Registry() -OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer) - LR_SCHEDULER_REGISTRY = _Registry() -LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) +CALLBACK_REGISTRY = _Registry() +MODEL_REGISTRY = _Registry() +DATAMODULE_REGISTRY = _Registry() +LOGGER_REGISTRY = _Registry() class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): @@ -103,17 +104,29 @@ class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): self.monitor = monitor -LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau) - -CALLBACK_REGISTRY = _Registry() -CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback) - -MODEL_REGISTRY = _Registry() - -DATAMODULE_REGISTRY = _Registry() - -LOGGER_REGISTRY = _Registry() -LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase) +def _populate_registries(subclasses: bool) -> None: + if subclasses: + # this will register any subclasses from all loaded modules including userland + for cls in get_all_subclasses(torch.optim.Optimizer): + OPTIMIZER_REGISTRY(cls) + for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler): + LR_SCHEDULER_REGISTRY(cls) + for cls in get_all_subclasses(pl.Callback): + CALLBACK_REGISTRY(cls) + for cls in get_all_subclasses(pl.LightningModule): + MODEL_REGISTRY(cls) + for cls in get_all_subclasses(pl.LightningDataModule): + DATAMODULE_REGISTRY(cls) + for cls in get_all_subclasses(pl.loggers.LightningLoggerBase): + LOGGER_REGISTRY(cls) + else: + # manually register torch's subclasses and our subclasses + 
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer) + LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) + CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback) + LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase) + # `ReduceLROnPlateau` does not subclass `_LRScheduler` + LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau) class LightningArgumentParser(ArgumentParser): @@ -465,6 +478,7 @@ class LightningCLI: subclass_mode_model: bool = False, subclass_mode_data: bool = False, run: bool = True, + auto_registry: bool = False, ) -> None: """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are called / instantiated using a parsed configuration file and / or command line args. @@ -508,6 +522,7 @@ class LightningCLI: of the given class. run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer` method. If set to ``False``, the trainer and model classes will be instantiated only. + auto_registry: Whether to automatically fill up the registries with all defined subclasses. 
""" self.save_config_callback = save_config_callback self.save_config_filename = save_config_filename @@ -527,6 +542,8 @@ class LightningCLI: self._datamodule_class = datamodule_class or LightningDataModule self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data + _populate_registries(auto_registry) + main_kwargs, subparser_kwargs = self._setup_parser_kwargs( parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463 {"description": description, "env_prefix": env_prefix, "default_env": env_parse}, diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 0b9b21193b..d14f111e87 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -147,7 +147,7 @@ else: # https://stackoverflow.com/a/63851681/9201239 -def get_all_subclasses(cls: Type[nn.Module]) -> Set[nn.Module]: +def get_all_subclasses(cls: Type) -> Set[Type]: subclass_list = [] def recurse(cl): diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 6d70567189..5da16737fc 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -38,6 +38,7 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import ( + _populate_registries, CALLBACK_REGISTRY, DATAMODULE_REGISTRY, instantiate_class, @@ -880,27 +881,38 @@ def test_lightning_cli_run(): assert isinstance(cli.model, LightningModule) -@OPTIMIZER_REGISTRY -class CustomAdam(torch.optim.Adam): - pass - - -@LR_SCHEDULER_REGISTRY -class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): - pass - - -@CALLBACK_REGISTRY -class CustomCallback(Callback): - pass - - -@LOGGER_REGISTRY -class CustomLogger(LightningLoggerBase): - pass +@pytest.fixture(autouse=True) +def clear_registries(): + # since the registries are global, it's good to clear them after 
each test to avoid unwanted interactions + yield + OPTIMIZER_REGISTRY.clear() + LR_SCHEDULER_REGISTRY.clear() + CALLBACK_REGISTRY.clear() + MODEL_REGISTRY.clear() + DATAMODULE_REGISTRY.clear() + LOGGER_REGISTRY.clear() def test_registries(): + # the registries are global so this is only necessary when this test is run standalone + _populate_registries(False) + + @OPTIMIZER_REGISTRY + class CustomAdam(torch.optim.Adam): + pass + + @LR_SCHEDULER_REGISTRY + class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR): + pass + + @CALLBACK_REGISTRY + class CustomCallback(Callback): + pass + + @LOGGER_REGISTRY + class CustomLogger(LightningLoggerBase): + pass + assert "SGD" in OPTIMIZER_REGISTRY.names assert "RMSprop" in OPTIMIZER_REGISTRY.names assert "CustomAdam" in OPTIMIZER_REGISTRY.names @@ -913,9 +925,13 @@ def test_registries(): assert "EarlyStopping" in CALLBACK_REGISTRY.names assert "CustomCallback" in CALLBACK_REGISTRY.names - with pytest.raises(MisconfigurationException, match="is already present in the registry"): - OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer) - OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True) + class Foo: + ... 
+ + OPTIMIZER_REGISTRY(Foo, key="SGD") # not overridden by default + assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD + OPTIMIZER_REGISTRY(Foo, key="SGD", override=True) + assert OPTIMIZER_REGISTRY["SGD"] is Foo # test `_Registry.__call__` returns the class assert isinstance(CustomCallback(), CustomCallback) @@ -924,7 +940,13 @@ def test_registries(): assert "CustomLogger" in LOGGER_REGISTRY -@MODEL_REGISTRY +def test_registries_register_automatically(): + assert "SaveConfigCallback" not in CALLBACK_REGISTRY + with mock.patch("sys.argv", ["any.py"]): + LightningCLI(BoringModel, run=False, auto_registry=True) + assert "SaveConfigCallback" in CALLBACK_REGISTRY + + class TestModel(BoringModel): def __init__(self, foo, bar=5): super().__init__() @@ -932,10 +954,10 @@ class TestModel(BoringModel): self.bar = bar -MODEL_REGISTRY(cls=BoringModel) - - def test_lightning_cli_model_choices(): + MODEL_REGISTRY(cls=TestModel) + MODEL_REGISTRY(cls=BoringModel) + with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch( "pytorch_lightning.Trainer._fit_impl" ) as run: @@ -950,7 +972,6 @@ def test_lightning_cli_model_choices(): assert cli.model.bar == 5 -@DATAMODULE_REGISTRY class MyDataModule(BoringDataModule): def __init__(self, foo, bar=5): super().__init__() @@ -958,10 +979,11 @@ class MyDataModule(BoringDataModule): self.bar = bar -DATAMODULE_REGISTRY(cls=BoringDataModule) - - def test_lightning_cli_datamodule_choices(): + MODEL_REGISTRY(cls=BoringModel) + DATAMODULE_REGISTRY(cls=MyDataModule) + DATAMODULE_REGISTRY(cls=BoringDataModule) + # with set model with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch( "pytorch_lightning.Trainer._fit_impl" @@ -998,7 +1020,7 @@ def test_lightning_cli_datamodule_choices(): assert not hasattr(cli.parser.groups["data"], "group_class") with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True): - cli = LightningCLI(BoringModel, run=False) + cli = 
LightningCLI(BoringModel, run=False, auto_registry=False) # no registered classes so not added automatically assert "data" not in cli.parser.groups assert len(DATAMODULE_REGISTRY) # check state was not modified @@ -1011,6 +1033,8 @@ def test_lightning_cli_datamodule_choices(): @pytest.mark.parametrize("use_class_path_callbacks", [False, True]) def test_registries_resolution(use_class_path_callbacks): + MODEL_REGISTRY(cls=BoringModel) + """This test validates registries are used when simplified command line are being used.""" cli_args = [ "--optimizer", @@ -1067,6 +1091,7 @@ def test_argv_transformation_single_callback(): } ] expected = base + ["--trainer.callbacks", str(callbacks)] + _populate_registries(False) argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input) assert argv == expected @@ -1090,6 +1115,7 @@ def test_argv_transformation_multiple_callbacks(): }, ] expected = base + ["--trainer.callbacks", str(callbacks)] + _populate_registries(False) argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input) assert argv == expected @@ -1117,6 +1143,7 @@ def test_argv_transformation_multiple_callbacks_with_config(): ] expected = base + ["--trainer.callbacks", str(callbacks)] nested_key = "trainer.callbacks" + _populate_registries(False) argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input) assert argv == expected @@ -1153,6 +1180,7 @@ def test_argv_transformation_multiple_callbacks_with_config(): def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry): base = ["any.py", "--trainer.max_epochs=1"] argv = base + args + _populate_registries(False) new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv) assert new_argv == base + [f"--{nested_key}", str(expected)]