Add `LightningCLI(auto_registry)` (#12108)
This commit is contained in:
parent cadcc67386
commit 8fa156948a
@@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `LightningCLI.configure_optimizers` to override the `configure_optimizers` return value ([#10860](https://github.com/PyTorchLightning/pytorch-lightning/pull/10860))

+- Added `LightningCLI(auto_registry)` flag to register all subclasses of the registerable components automatically ([#12108](https://github.com/PyTorchLightning/pytorch-lightning/pull/12108))
+
 - Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/pull/10700))

@@ -345,6 +345,29 @@ This can be useful to implement custom logic without having to subclass the CLI,
 and argument parsing capabilities.


+Subclass registration
+^^^^^^^^^^^^^^^^^^^^^
+
+To use shorthand notation, the options need to be registered beforehand. This can easily be done with:
+
+.. code-block:: python
+
+    LightningCLI(auto_registry=True)  # False by default
+
+which will register all subclasses of :class:`torch.optim.Optimizer`, :class:`torch.optim.lr_scheduler._LRScheduler`,
+:class:`~pytorch_lightning.core.lightning.LightningModule`,
+:class:`~pytorch_lightning.core.datamodule.LightningDataModule`, :class:`~pytorch_lightning.callbacks.Callback`, and
+:class:`~pytorch_lightning.loggers.LightningLoggerBase` across all imported modules. This includes the subclasses
+defined in your own code.
+
+Alternatively, if this is left unset, only the subclasses of :class:`torch.optim.Optimizer` and
+:class:`torch.optim.lr_scheduler._LRScheduler` shipped with PyTorch, and Lightning's own
+:class:`~pytorch_lightning.callbacks.Callback` and :class:`~pytorch_lightning.loggers.LightningLoggerBase`
+subclasses, will be registered.
+
+In subsequent sections, we will go over adding specific classes to specific registries as well as how to use
+shorthand notation.
+
+
 Trainer Callbacks and arguments with class type
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

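For orientation, a minimal sketch of what the new flag enables from user code; `MyModel`, `my_project.models`, and `LitAdam` are hypothetical names, not part of this commit:

    import torch
    from pytorch_lightning.utilities.cli import LightningCLI

    from my_project.models import MyModel  # hypothetical user model

    class LitAdam(torch.optim.Adam):
        """A user-defined optimizer; note that it is never registered manually."""

    # Because LitAdam is already imported when the CLI is constructed,
    # auto_registry=True discovers it, so shorthand notation such as
    #   python trainer.py fit --optimizer=LitAdam --optimizer.lr=0.01
    # resolves to this class.
    cli = LightningCLI(MyModel, auto_registry=True)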
@@ -30,6 +30,7 @@ from pytorch_lightning import Callback, LightningDataModule, LightningModule, se
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
+from pytorch_lightning.utilities.meta import get_all_subclasses
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn
 from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion

@@ -58,9 +59,8 @@ class _Registry(dict):
         elif not isinstance(key, str):
             raise TypeError(f"`key` must be a str, found {key}")

-        if key in self and not override:
-            raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
-        self[key] = cls
+        if key not in self or override:
+            self[key] = cls
         return cls

     def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:

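The behavioral change above: re-registering an existing key becomes a silent no-op instead of raising `MisconfigurationException`, unless `override=True` is passed. A simplified sketch of the new semantics, with a plain dict standing in for `_Registry`:

    registry = {}

    def register(key, cls, override=False):
        # mirrors the new `_Registry` logic: only write when absent or overriding
        if key not in registry or override:
            registry[key] = cls
        return cls

    class A: ...
    class B: ...

    register("opt", A)
    register("opt", B)                 # silently ignored: key already present
    assert registry["opt"] is A
    register("opt", B, override=True)  # explicit override replaces the entry
    assert registry["opt"] is B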
@@ -91,10 +91,11 @@ class _Registry(dict):


 OPTIMIZER_REGISTRY = _Registry()
-OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
-
 LR_SCHEDULER_REGISTRY = _Registry()
-LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+CALLBACK_REGISTRY = _Registry()
+MODEL_REGISTRY = _Registry()
+DATAMODULE_REGISTRY = _Registry()
+LOGGER_REGISTRY = _Registry()


 class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):

@@ -103,17 +104,29 @@ class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
         self.monitor = monitor


-LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)
-
-CALLBACK_REGISTRY = _Registry()
-CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)
-
-MODEL_REGISTRY = _Registry()
-
-DATAMODULE_REGISTRY = _Registry()
-
-LOGGER_REGISTRY = _Registry()
-LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase)
+def _populate_registries(subclasses: bool) -> None:
+    if subclasses:
+        # this will register any subclasses from all loaded modules including userland
+        for cls in get_all_subclasses(torch.optim.Optimizer):
+            OPTIMIZER_REGISTRY(cls)
+        for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
+            LR_SCHEDULER_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.Callback):
+            CALLBACK_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.LightningModule):
+            MODEL_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.LightningDataModule):
+            DATAMODULE_REGISTRY(cls)
+        for cls in get_all_subclasses(pl.loggers.LightningLoggerBase):
+            LOGGER_REGISTRY(cls)
+    else:
+        # manually register torch's subclasses and our subclasses
+        OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
+        LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
+        CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback)
+        LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase)
+        # `ReduceLROnPlateau` does not subclass `_LRScheduler`
+        LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)


 class LightningArgumentParser(ArgumentParser):

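Roughly, the two branches of `_populate_registries` behave as follows; a sketch assuming `pytorch_lightning` is importable and `MyOptimizer` is a hypothetical user-defined class:

    import torch
    from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, _populate_registries

    class MyOptimizer(torch.optim.SGD):
        """A user-defined subclass living outside pytorch_lightning."""

    _populate_registries(False)                 # manual path: built-ins only
    assert "SGD" in OPTIMIZER_REGISTRY
    assert "MyOptimizer" not in OPTIMIZER_REGISTRY

    _populate_registries(True)                  # walks every loaded subclass
    assert "MyOptimizer" in OPTIMIZER_REGISTRY  # userland class is now registered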
@@ -465,6 +478,7 @@ class LightningCLI:
         subclass_mode_model: bool = False,
         subclass_mode_data: bool = False,
         run: bool = True,
+        auto_registry: bool = False,
     ) -> None:
         """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
         are called / instantiated using a parsed configuration file and / or command line args.

@@ -508,6 +522,7 @@ class LightningCLI:
                 of the given class.
             run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                 method. If set to ``False``, the trainer and model classes will be instantiated only.
+            auto_registry: Whether to automatically fill up the registries with all defined subclasses.
         """
         self.save_config_callback = save_config_callback
         self.save_config_filename = save_config_filename

@@ -527,6 +542,8 @@ class LightningCLI:
         self._datamodule_class = datamodule_class or LightningDataModule
         self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

+        _populate_registries(auto_registry)
+
         main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
             parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
             {"description": description, "env_prefix": env_prefix, "default_env": env_parse},

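With this call in place, populating the registries becomes a construction-time side effect of the CLI; a sketch mirroring the test added later in this commit (the `tests.helpers` import path for `BoringModel` is an assumption):

    from unittest import mock

    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY, LightningCLI
    from tests.helpers import BoringModel  # assumed location of the test helper

    with mock.patch("sys.argv", ["any.py"]):
        # run=False: classes are instantiated, but no Trainer method is executed
        LightningCLI(BoringModel, run=False, auto_registry=True)
    assert "SaveConfigCallback" in CALLBACK_REGISTRY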
@@ -147,7 +147,7 @@ else:


 # https://stackoverflow.com/a/63851681/9201239
-def get_all_subclasses(cls: Type[nn.Module]) -> Set[nn.Module]:
+def get_all_subclasses(cls: Type) -> Set[Type]:
     subclass_list = []

     def recurse(cl):

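For reference, the generalized helper is a small recursive walk over `__subclasses__()` in the spirit of the linked StackOverflow answer; a self-contained sketch:

    from typing import Set, Type

    def get_all_subclasses(cls: Type) -> Set[Type]:
        subclass_list = []

        def recurse(cl: Type) -> None:
            # `__subclasses__()` only yields direct children, hence the recursion
            for subclass in cl.__subclasses__():
                subclass_list.append(subclass)
                recurse(subclass)

        recurse(cls)
        return set(subclass_list)

    class Base: ...
    class Child(Base): ...
    class GrandChild(Child): ...

    assert get_all_subclasses(Base) == {Child, GrandChild}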
@@ -38,6 +38,7 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.cli import (
+    _populate_registries,
     CALLBACK_REGISTRY,
     DATAMODULE_REGISTRY,
     instantiate_class,

@@ -880,27 +881,38 @@ def test_lightning_cli_run():
     assert isinstance(cli.model, LightningModule)


-@OPTIMIZER_REGISTRY
-class CustomAdam(torch.optim.Adam):
-    pass
-
-
-@LR_SCHEDULER_REGISTRY
-class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
-    pass
-
-
-@CALLBACK_REGISTRY
-class CustomCallback(Callback):
-    pass
-
-
-@LOGGER_REGISTRY
-class CustomLogger(LightningLoggerBase):
-    pass
+@pytest.fixture(autouse=True)
+def clear_registries():
+    # since the registries are global, it's good to clear them after each test to avoid unwanted interactions
+    yield
+    OPTIMIZER_REGISTRY.clear()
+    LR_SCHEDULER_REGISTRY.clear()
+    CALLBACK_REGISTRY.clear()
+    MODEL_REGISTRY.clear()
+    DATAMODULE_REGISTRY.clear()
+    LOGGER_REGISTRY.clear()


 def test_registries():
+    # the registries are global so this is only necessary when this test is run standalone
+    _populate_registries(False)
+
+    @OPTIMIZER_REGISTRY
+    class CustomAdam(torch.optim.Adam):
+        pass
+
+    @LR_SCHEDULER_REGISTRY
+    class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
+        pass
+
+    @CALLBACK_REGISTRY
+    class CustomCallback(Callback):
+        pass
+
+    @LOGGER_REGISTRY
+    class CustomLogger(LightningLoggerBase):
+        pass
+
     assert "SGD" in OPTIMIZER_REGISTRY.names
     assert "RMSprop" in OPTIMIZER_REGISTRY.names
     assert "CustomAdam" in OPTIMIZER_REGISTRY.names

@@ -913,9 +925,13 @@ def test_registries():
     assert "EarlyStopping" in CALLBACK_REGISTRY.names
     assert "CustomCallback" in CALLBACK_REGISTRY.names

-    with pytest.raises(MisconfigurationException, match="is already present in the registry"):
-        OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
-    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)
+    class Foo:
+        ...
+
+    OPTIMIZER_REGISTRY(Foo, key="SGD")  # not overridden by default
+    assert OPTIMIZER_REGISTRY["SGD"] is torch.optim.SGD
+    OPTIMIZER_REGISTRY(Foo, key="SGD", override=True)
+    assert OPTIMIZER_REGISTRY["SGD"] is Foo

     # test `_Registry.__call__` returns the class
     assert isinstance(CustomCallback(), CustomCallback)

@@ -924,7 +940,13 @@ def test_registries():
     assert "CustomLogger" in LOGGER_REGISTRY


-@MODEL_REGISTRY
+def test_registries_register_automatically():
+    assert "SaveConfigCallback" not in CALLBACK_REGISTRY
+    with mock.patch("sys.argv", ["any.py"]):
+        LightningCLI(BoringModel, run=False, auto_registry=True)
+    assert "SaveConfigCallback" in CALLBACK_REGISTRY
+
+
 class TestModel(BoringModel):
     def __init__(self, foo, bar=5):
         super().__init__()

@@ -932,10 +954,10 @@ class TestModel(BoringModel):
         self.bar = bar


-MODEL_REGISTRY(cls=BoringModel)
-
-
 def test_lightning_cli_model_choices():
+    MODEL_REGISTRY(cls=TestModel)
+    MODEL_REGISTRY(cls=BoringModel)
+
     with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
         "pytorch_lightning.Trainer._fit_impl"
     ) as run:

@@ -950,7 +972,6 @@ def test_lightning_cli_model_choices():
     assert cli.model.bar == 5


-@DATAMODULE_REGISTRY
 class MyDataModule(BoringDataModule):
     def __init__(self, foo, bar=5):
         super().__init__()

@@ -958,10 +979,11 @@ class MyDataModule(BoringDataModule):
         self.bar = bar


-DATAMODULE_REGISTRY(cls=BoringDataModule)
-
-
 def test_lightning_cli_datamodule_choices():
+    MODEL_REGISTRY(cls=BoringModel)
+    DATAMODULE_REGISTRY(cls=MyDataModule)
+    DATAMODULE_REGISTRY(cls=BoringDataModule)
+
     # with set model
     with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
         "pytorch_lightning.Trainer._fit_impl"

@@ -998,7 +1020,7 @@ def test_lightning_cli_datamodule_choices():
     assert not hasattr(cli.parser.groups["data"], "group_class")

     with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True):
-        cli = LightningCLI(BoringModel, run=False)
+        cli = LightningCLI(BoringModel, run=False, auto_registry=False)
         # no registered classes so not added automatically
         assert "data" not in cli.parser.groups
     assert len(DATAMODULE_REGISTRY)  # check state was not modified

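The `mock.patch.dict(..., clear=True)` idiom above empties the registry for the duration of the block and restores it on exit, which is why the final assertion can check that the global state survived; a minimal sketch:

    from unittest import mock

    REGISTRY = {"BoringDataModule": object}

    with mock.patch.dict(REGISTRY, clear=True):
        assert len(REGISTRY) == 0          # cleared inside the context
    assert "BoringDataModule" in REGISTRY  # original contents restored on exit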
@@ -1011,6 +1033,8 @@ def test_lightning_cli_datamodule_choices():

 @pytest.mark.parametrize("use_class_path_callbacks", [False, True])
 def test_registries_resolution(use_class_path_callbacks):
+    MODEL_REGISTRY(cls=BoringModel)
+
     """This test validates registries are used when simplified command line are being used."""
     cli_args = [
         "--optimizer",

@@ -1067,6 +1091,7 @@ def test_argv_transformation_single_callback():
         }
     ]
     expected = base + ["--trainer.callbacks", str(callbacks)]
+    _populate_registries(False)
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
    assert argv == expected

@@ -1090,6 +1115,7 @@ def test_argv_transformation_multiple_callbacks():
         },
     ]
     expected = base + ["--trainer.callbacks", str(callbacks)]
+    _populate_registries(False)
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
    assert argv == expected

@@ -1117,6 +1143,7 @@ def test_argv_transformation_multiple_callbacks_with_config():
     ]
     expected = base + ["--trainer.callbacks", str(callbacks)]
     nested_key = "trainer.callbacks"
+    _populate_registries(False)
    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
    assert argv == expected

@@ -1153,6 +1180,7 @@ def test_argv_transformation_multiple_callbacks_with_config():
 def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
     base = ["any.py", "--trainer.max_epochs=1"]
     argv = base + args
+    _populate_registries(False)
     new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
     assert new_argv == base + [f"--{nested_key}", str(expected)]