Add `LightningCLI(auto_registry)` (#12108)

Carlos Mocholí 2022-03-08 18:26:10 +01:00 committed by GitHub
parent cadcc67386
commit 8fa156948a
5 changed files with 119 additions and 48 deletions

CHANGELOG.md

@@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/pull/10860))
+- Added `LightningCLI(auto_registry)` flag to register all subclasses of the registerable components automatically ([#12108](https://github.com/PyTorchLightning/pytorch-lightning/pull/12108))
 - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/pull/10700))

docs/source/common/lightning_cli.rst

@@ -345,6 +345,29 @@ This can be useful to implement custom logic without having to subclass the CLI,
 and argument parsing capabilities.
+
+Subclass registration
+^^^^^^^^^^^^^^^^^^^^^
+
+To use shorthand notation, the options need to be registered beforehand. This can be easily done with:
+
+.. code-block::
+
+    LightningCLI(auto_registry=True)  # False by default
+
+which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`,
+:class:`~pytorch_lightning.core.lightning.LightningModule`,
+:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and
+:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules, including those defined in your
+own code.
+
+Alternatively, if this is left unset, only the :class:`torch.optim.Optimizer` and
+:class:`torch.optim.lr_scheduler._LRScheduler` subclasses defined in PyTorch and the
+:class:`~pytorch_lightning.callbacks.Callback` and :class:`~pytorch_lightning.loggers.LightningLoggerBase` subclasses
+defined in Lightning will be registered.
+
+In subsequent sections, we will go over adding specific classes to specific registries as well as how to use
+shorthand notation.
+
 Trainer Callbacks and arguments with class type
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
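For orientation beyond the diff itself, here is a minimal sketch of what this registration buys on the command line. The script name `cli.py` and the classes `MyModel` and `MyCallback` are hypothetical stand-ins, not part of this commit:

# cli.py -- hedged sketch of shorthand notation enabled by auto_registry
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class MyCallback(Callback):  # subclasses Callback, so auto_registry picks it up
    pass


class MyModel(LightningModule):  # training logic omitted for brevity
    pass


if __name__ == "__main__":
    # `python cli.py fit --trainer.callbacks=MyCallback` now resolves the
    # callback by its short name instead of requiring the full class path.
    LightningCLI(MyModel, auto_registry=True)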

pytorch_lightning/utilities/cli.py

@@ -30,6 +30,7 @@ from pytorch_lightning import Callback, LightningDataModule, LightningModule, se
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
+from pytorch_lightning.utilities.meta import get_all_subclasses
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn
 from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion
@@ -58,9 +59,8 @@ class _Registry(dict):
         elif not isinstance(key, str):
             raise TypeError(f"`key` must be a str, found {key}")

-        if key in self and not override:
-            raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
-        self[key] = cls
+        if key not in self or override:
+            self[key] = cls
         return cls

     def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
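Note the semantic change in `__call__` above: registering a duplicate key used to raise a `MisconfigurationException` and is now a silent no-op unless `override=True` is passed. A condensed illustration of the new behavior, mirroring the updated test near the end of this diff (`Foo` is a throwaway class):

import torch
from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, _populate_registries

_populate_registries(False)  # fill in the defaults, e.g. torch.optim.SGD


class Foo:  # dummy class standing in for an accidental name clash
    ...


OPTIMIZER_REGISTRY(Foo, key="SGD")                 # duplicate key: silently kept as-is
assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD
OPTIMIZER_REGISTRY(Foo, key="SGD", override=True)  # explicit override replaces it
assert OPTIMIZER_REGISTRY["SGD"] is Foo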
@@ -91,10 +91,11 @@
 OPTIMIZER_REGISTRY = _Registry()
-OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
-
 LR_SCHEDULER_REGISTRY = _Registry()
-LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+CALLBACK_REGISTRY = _Registry()
+MODEL_REGISTRY = _Registry()
+DATAMODULE_REGISTRY = _Registry()
+LOGGER_REGISTRY = _Registry()


 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
@@ -103,17 +104,29 @@ class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
         self.monitor = monitor

-LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)
-
-
-CALLBACK_REGISTRY = _Registry()
-CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)
-
-MODEL_REGISTRY = _Registry()
-
-DATAMODULE_REGISTRY = _Registry()
-
-LOGGER_REGISTRY = _Registry()
-LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase)
+def _populate_registries(subclasses: bool) -> None:
+    if subclasses:
+        # this will register any subclasses from all loaded modules including userland
+        for cls in get_all_subclasses(torch.optim.Optimizer):
+            OPTIMIZER_REGISTRY(cls)
+        for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
+            LR_SCHEDULER_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.Callback):
+            CALLBACK_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.LightningModule):
+            MODEL_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.LightningDataModule):
+            DATAMODULE_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.loggers.LightningLoggerBase):
+            LOGGER_REGISTRY(cls)
+    else:
+        # manually register torch's subclasses and our subclasses
+        OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
+        LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+        CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback)
+        LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase)
+    # `ReduceLROnPlateau` does not subclass `_LRScheduler`
+    LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)


 class LightningArgumentParser(ArgumentParser):
@@ -465,6 +478,7 @@ class LightningCLI:
         subclass_mode_model: bool = False,
         subclass_mode_data: bool = False,
         run: bool = True,
+        auto_registry: bool = False,
     ) -> None:
         """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
         are called / instantiated using a parsed configuration file and / or command line args.
@@ -508,6 +522,7 @@ class LightningCLI:
                 of the given class.
             run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                 method. If set to ``False``, the trainer and model classes will be instantiated only.
+            auto_registry: Whether to automatically fill up the registries with all defined subclasses.
         """
         self.save_config_callback = save_config_callback
         self.save_config_filename = save_config_filename
@@ -527,6 +542,8 @@ class LightningCLI:
         self._datamodule_class = datamodule_class or LightningDataModule
         self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

+        _populate_registries(auto_registry)
+
         main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
             parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
             {"description": description, "env_prefix": env_prefix, "default_env": env_parse},

pytorch_lightning/utilities/meta.py

@@ -147,7 +147,7 @@ else:
 # https://stackoverflow.com/a/63851681/9201239
-def get_all_subclasses(cls: Type[nn.Module]) -> Set[nn.Module]:
+def get_all_subclasses(cls: Type) -> Set[Type]:
     subclass_list = []

     def recurse(cl):
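For reference, the recursive walk behind `get_all_subclasses` (per the Stack Overflow answer cited in the hunk) boils down to something like this sketch; the real helper lives in `pytorch_lightning/utilities/meta.py` and this commit only widens its type hints from `nn.Module` to any class:

from typing import Set, Type


def get_all_subclasses(cls: Type) -> Set[Type]:
    # collect direct and transitive subclasses by walking __subclasses__()
    subclasses: Set[Type] = set()

    def recurse(cl: Type) -> None:
        for subclass in cl.__subclasses__():
            subclasses.add(subclass)
            recurse(subclass)

    recurse(cls)
    return subclasses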

tests/utilities/test_cli.py

@@ -38,6 +38,7 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.cli import (
+    _populate_registries,
     CALLBACK_REGISTRY,
     DATAMODULE_REGISTRY,
     instantiate_class,
@@ -880,27 +881,38 @@ def test_lightning_cli_run():
     assert isinstance(cli.model, LightningModule)


-@OPTIMIZER_REGISTRY
-class CustomAdam(torch.optim.Adam):
-    pass
-
-
-@LR_SCHEDULER_REGISTRY
-class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
-    pass
-
-
-@CALLBACK_REGISTRY
-class CustomCallback(Callback):
-    pass
-
-
-@LOGGER_REGISTRY
-class CustomLogger(LightningLoggerBase):
-    pass
+@pytest.fixture(autouse=True)
+def clear_registries():
+    # since the registries are global, it's good to clear them after each test to avoid unwanted interactions
+    yield
+    OPTIMIZER_REGISTRY.clear()
+    LR_SCHEDULER_REGISTRY.clear()
+    CALLBACK_REGISTRY.clear()
+    MODEL_REGISTRY.clear()
+    DATAMODULE_REGISTRY.clear()
+    LOGGER_REGISTRY.clear()


 def test_registries():
+    # the registries are global so this is only necessary when this test is run standalone
+    _populate_registries(False)
+
+    @OPTIMIZER_REGISTRY
+    class CustomAdam(torch.optim.Adam):
+        pass
+
+    @LR_SCHEDULER_REGISTRY
+    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
+        pass
+
+    @CALLBACK_REGISTRY
+    class CustomCallback(Callback):
+        pass
+
+    @LOGGER_REGISTRY
+    class CustomLogger(LightningLoggerBase):
+        pass
+
     assert "SGD" in OPTIMIZER_REGISTRY.names
     assert "RMSprop" in OPTIMIZER_REGISTRY.names
     assert "CustomAdam" in OPTIMIZER_REGISTRY.names
@@ -913,9 +925,13 @@ def test_registries():
     assert "EarlyStopping" in CALLBACK_REGISTRY.names
     assert "CustomCallback" in CALLBACK_REGISTRY.names

-    with pytest.raises(MisconfigurationException, match="is already present in the registry"):
-        OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
-    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)
+    class Foo:
+        ...
+
+    OPTIMIZER_REGISTRY(Foo, key="SGD")  # not overridden by default
+    assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD
+    OPTIMIZER_REGISTRY(Foo, key="SGD", override=True)
+    assert OPTIMIZER_REGISTRY["SGD"] is Foo

     # test `_Registry.__call__` returns the class
     assert isinstance(CustomCallback(), CustomCallback)
@@ -924,7 +940,13 @@ def test_registries():
     assert "CustomLogger" in LOGGER_REGISTRY


-@MODEL_REGISTRY
+def test_registries_register_automatically():
+    assert "SaveConfigCallback" not in CALLBACK_REGISTRY
+    with mock.patch("sys.argv", ["any.py"]):
+        LightningCLI(BoringModel, run=False, auto_registry=True)
+    assert "SaveConfigCallback" in CALLBACK_REGISTRY
+
+
 class TestModel(BoringModel):
     def __init__(self, foo, bar=5):
         super().__init__()
@@ -932,10 +954,10 @@ class TestModel(BoringModel):
         self.bar = bar


-MODEL_REGISTRY(cls=BoringModel)
-
-
 def test_lightning_cli_model_choices():
+    MODEL_REGISTRY(cls=TestModel)
+    MODEL_REGISTRY(cls=BoringModel)
     with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
         "pytorch_lightning.Trainer._fit_impl"
     ) as run:
@@ -950,7 +972,6 @@ def test_lightning_cli_model_choices():
     assert cli.model.bar == 5


-@DATAMODULE_REGISTRY
 class MyDataModule(BoringDataModule):
     def __init__(self, foo, bar=5):
         super().__init__()
@@ -958,10 +979,11 @@ class MyDataModule(BoringDataModule):
         self.bar = bar


-DATAMODULE_REGISTRY(cls=BoringDataModule)
-
-
 def test_lightning_cli_datamodule_choices():
+    MODEL_REGISTRY(cls=BoringModel)
+    DATAMODULE_REGISTRY(cls=MyDataModule)
+    DATAMODULE_REGISTRY(cls=BoringDataModule)
     # with set model
     with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
         "pytorch_lightning.Trainer._fit_impl"
@@ -998,7 +1020,7 @@ def test_lightning_cli_datamodule_choices():
     assert not hasattr(cli.parser.groups["data"], "group_class")

     with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True):
-        cli = LightningCLI(BoringModel, run=False)
+        cli = LightningCLI(BoringModel, run=False, auto_registry=False)
         # no registered classes so not added automatically
         assert "data" not in cli.parser.groups
     assert len(DATAMODULE_REGISTRY)  # check state was not modified
@@ -1011,6 +1033,8 @@ def test_lightning_cli_datamodule_choices():
 @pytest.mark.parametrize("use_class_path_callbacks", [False, True])
 def test_registries_resolution(use_class_path_callbacks):
+    MODEL_REGISTRY(cls=BoringModel)
+
     """This test validates registries are used when simplified command line are being used."""
     cli_args = [
         "--optimizer",
@@ -1067,6 +1091,7 @@ def test_argv_transformation_single_callback():
         }
     ]
     expected = base + ["--trainer.callbacks", str(callbacks)]
+    _populate_registries(False)
     argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
     assert argv == expected
@@ -1090,6 +1115,7 @@ def test_argv_transformation_multiple_callbacks():
         },
     ]
     expected = base + ["--trainer.callbacks", str(callbacks)]
+    _populate_registries(False)
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
     assert argv == expected
@@ -1117,6 +1143,7 @@ def test_argv_transformation_multiple_callbacks_with_config():
     ]
     expected = base + ["--trainer.callbacks", str(callbacks)]
     nested_key = "trainer.callbacks"
+    _populate_registries(False)
     argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
     assert argv == expected
@@ -1153,6 +1180,7 @@ def test_argv_transformation_multiple_callbacks_with_config():
 def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
     base = ["any.py", "--trainer.max_epochs=1"]
     argv = base + args
+    _populate_registries(False)
     new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
     assert new_argv == base + [f"--{nested_key}", str(expected)]