[CLI] Shorthand notation to instantiate callbacks [3/3] (#8815)
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
parent bbcb977851
commit 1bb5fccb71
@@ -19,4 +19,4 @@ jobs:
         run: |
           grep mypy requirements/test.txt | xargs -0 pip install
           pip list
-      - run: mypy
+      - run: mypy --install-types --non-interactive
@@ -57,6 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
     * Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
     * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
+    * Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815))


 - Fault-tolerant training:
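The first two entries above refer to the registry mechanism introduced in #9565, which this PR extends to callbacks. As a reminder of what "registering without subclassing the CLI" looks like, here is a minimal sketch; the `CustomOptimizer` and `CustomScheduler` names are illustrative and not part of this diff:

    import torch
    from pytorch_lightning.utilities.cli import LR_SCHEDULER_REGISTRY, OPTIMIZER_REGISTRY


    @OPTIMIZER_REGISTRY
    class CustomOptimizer(torch.optim.SGD):
        """Becomes selectable as ``--optimizer=CustomOptimizer`` on the command line."""


    @LR_SCHEDULER_REGISTRY
    class CustomScheduler(torch.optim.lr_scheduler.StepLR):
        """Becomes selectable as ``--lr_scheduler=CustomScheduler``."""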
@@ -5,7 +5,7 @@
 from unittest import mock
 from typing import List

 import pytorch_lightning as pl
-from pytorch_lightning import LightningModule, LightningDataModule, Trainer
+from pytorch_lightning import LightningModule, LightningDataModule, Trainer, Callback


 class NoFitTrainer(Trainer):
@@ -371,6 +371,59 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr
 :class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured
 the same way using :code:`class_path` and :code:`init_args`.

+For callbacks in particular, Lightning simplifies the command line so that only
+the :class:`~pytorch_lightning.callbacks.Callback` name is required.
+The order of the arguments matters: each callback's arguments must directly follow its name, as shown below.
+
+.. code-block:: bash
+
+    $ python ... \
+        --trainer.callbacks={CALLBACK_1_NAME} \
+        --trainer.callbacks.{CALLBACK_1_ARGS_1}=... \
+        --trainer.callbacks.{CALLBACK_1_ARGS_2}=... \
+        ...
+        --trainer.callbacks={CALLBACK_N_NAME} \
+        --trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
+        ...
+
+Here is an example:
+
+.. code-block:: bash
+
+    $ python ... \
+        --trainer.callbacks=EarlyStopping \
+        --trainer.callbacks.patience=5 \
+        --trainer.callbacks=LearningRateMonitor \
+        --trainer.callbacks.logging_interval=epoch
+
+Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
+as described above:
+
+.. code-block:: python
+
+    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
+
+
+    @CALLBACK_REGISTRY
+    class CustomCallback(Callback):
+        ...
+
+
+    cli = LightningCLI(...)
+
+.. code-block:: bash
+
+    $ python ... --trainer.callbacks=CustomCallback ...
+
+This callback will be included in the generated config:
+
+.. code-block:: yaml
+
+    trainer:
+      callbacks:
+        - class_path: your_class_path.CustomCallback
+          init_args:
+            ...
+
 Multiple models and/or datasets
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
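Tying the documentation above together, here is a minimal runnable sketch of a CLI script that accepts the shorthand callback syntax. The file name ``demo.py`` and the ``DemoModel`` class are hypothetical; only APIs shown in this diff plus the public ``LightningModule`` interface are assumed.

    # demo.py -- hypothetical example script
    import torch
    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import LightningCLI


    class DemoModel(LightningModule):
        """A stub model; a real project would also define training_step/configure_optimizers."""

        def __init__(self, hidden_dim: int = 8):
            super().__init__()
            self.layer = torch.nn.Linear(32, hidden_dim)


    if __name__ == "__main__":
        # run=False parses the arguments and instantiates the trainer without fitting,
        # mirroring how the tests added later in this PR exercise the CLI.
        cli = LightningCLI(DemoModel, run=False)
        print([type(cb).__name__ for cb in cli.trainer.callbacks])

Invoked as ``python demo.py --trainer.callbacks=EarlyStopping --trainer.callbacks.patience=5``, the printed list should include ``EarlyStopping``, resolved through ``CALLBACK_REGISTRY``.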
@@ -517,9 +570,10 @@ instantiating the trainer class can be found in :code:`self.config['fit']['train
 Configurable callbacks
 ^^^^^^^^^^^^^^^^^^^^^^

-As explained previously, any callback can be added by including it in the config via :code:`class_path` and
-:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
-configurable. This can be implemented as follows:
+As explained previously, any Lightning callback can be added by passing it through the command line or by
+including it in the config via :code:`class_path` and :code:`init_args` entries.
+However, there are other cases in which a callback should always be present and be configurable.
+This can be implemented as follows:

 .. testcode::

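The body of that ``testcode`` block falls outside this hunk. As a reminder of the pattern it refers to, here is a sketch, assuming the documented ``add_arguments_to_parser`` hook, that always adds ``EarlyStopping`` and exposes its arguments under a dedicated key:

    from pytorch_lightning.callbacks import EarlyStopping
    from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
            # the callback is always present; its init args become configurable
            # under --my_early_stopping.* and in the config file
            parser.add_lightning_class_args(EarlyStopping, "my_early_stopping")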
@@ -20,8 +20,10 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from unittest import mock

 import torch
+import yaml
 from torch.optim import Optimizer

+import pytorch_lightning as pl
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
 from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -83,12 +85,15 @@ OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
 LR_SCHEDULER_REGISTRY = _Registry()
 LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)

+CALLBACK_REGISTRY = _Registry()
+CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)
+

 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

     # use class attribute because `parse_args` is only called on the main parser
-    _choices: Dict[str, Tuple[Type, ...]] = {}
+    _choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}

     def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
         """Initialize argument parser that supports configuration file input.
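A small sketch of how the new ``CALLBACK_REGISTRY`` behaves, mirroring the decorator usage and the ``.names`` accessor exercised elsewhere in this PR; ``MyCallback`` is illustrative only:

    from pytorch_lightning import Callback
    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY


    @CALLBACK_REGISTRY
    class MyCallback(Callback):
        ...


    # built-in callbacks are registered via register_classes(pl.callbacks, Callback) above
    assert "EarlyStopping" in CALLBACK_REGISTRY.names
    # user callbacks registered through the decorator become valid shorthand names too
    assert "MyCallback" in CALLBACK_REGISTRY.names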
@@ -202,23 +207,35 @@ class LightningArgumentParser(ArgumentParser):
     def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         argv = sys.argv
-        for k, classes in self._choices.items():
+        for k, v in self._choices.items():
             if not any(arg.startswith(f"--{k}") for arg in argv):
                 # the key wasn't passed - maybe defined in a config, maybe it's optional
                 continue
-            argv = self._convert_argv_issue_84(classes, k, argv)
+            classes, is_list = v
+            # knowing whether the argument is a list type automatically would be too complex
+            if is_list:
+                argv = self._convert_argv_issue_85(classes, k, argv)
+            else:
+                argv = self._convert_argv_issue_84(classes, k, argv)
         self._choices.clear()
         with mock.patch("sys.argv", argv):
             return super().parse_args(*args, **kwargs)

-    def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
-        self._choices[nested_key] = classes
+    def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None:
+        """Adds support for shorthand notation for a particular nested key.
+
+        Args:
+            nested_key: The key whose choices will be set.
+            classes: A tuple of classes to choose from.
+            is_list: Whether the argument is a ``List[object]`` type.
+        """
+        self._choices[nested_key] = (classes, is_list)

     @staticmethod
     def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
         """Placeholder for https://github.com/omni-us/jsonargparse/issues/84.

         This should be removed once implemented.
         Adds support for shorthand notation for ``object`` arguments.
         """
         passed_args, clean_argv = {}, []
         argv_key = f"--{nested_key}"
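For context, here is an illustrative call (not part of the diff) showing what the non-list ``_convert_argv_issue_84`` rewrite is expected to do, using the same optimizer shorthand as the tests in this PR; the exact string format of the generated config is an implementation detail of the method:

    from pytorch_lightning.utilities.cli import LightningArgumentParser, OPTIMIZER_REGISTRY

    argv = ["any.py", "--optimizer", "Adam", "--optimizer.lr", "0.01"]
    new_argv = LightningArgumentParser._convert_argv_issue_84(OPTIMIZER_REGISTRY.classes, "optimizer", argv)
    # expected shape: ["any.py", "--optimizer", "<config with class_path torch.optim.Adam and init_args lr>"]
    print(new_argv)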
@@ -259,6 +276,64 @@ class LightningArgumentParser(ArgumentParser):
             raise ValueError(f"Could not generate a config for {repr(argv_class)}")
         return clean_argv + [argv_key, config]

+    @staticmethod
+    def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
+        """Placeholder for https://github.com/omni-us/jsonargparse/issues/85.
+
+        Adds support for shorthand notation for ``List[object]`` arguments.
+        """
+        passed_args, clean_argv = [], []
+        passed_configs = {}
+        argv_key = f"--{nested_key}"
+        # get the argv args for this nested key
+        i = 0
+        while i < len(argv):
+            arg = argv[i]
+            if arg.startswith(argv_key):
+                if "=" in arg:
+                    key, value = arg.split("=")
+                else:
+                    key = arg
+                    i += 1
+                    value = argv[i]
+                if "class_path" in value:
+                    # the user passed a config as a dict
+                    passed_configs[key] = yaml.safe_load(value)
+                else:
+                    passed_args.append((key, value))
+            else:
+                clean_argv.append(arg)
+            i += 1
+        # generate the associated config file
+        config = []
+        i, n = 0, len(passed_args)
+        while i < n - 1:
+            ki, vi = passed_args[i]
+            # convert class name to class path
+            for cls in classes:
+                if cls.__name__ == vi:
+                    cls_type = cls
+                    break
+            else:
+                raise ValueError(f"Could not generate a config for {repr(vi)}")
+            config.append(_global_add_class_path(cls_type))
+            # get any init args
+            j = i + 1  # in case the j-loop doesn't run
+            for j in range(i + 1, n):
+                kj, vj = passed_args[j]
+                if ki == kj:
+                    break
+                if kj.startswith(ki):
+                    init_arg_name = kj.split(".")[-1]
+                    config[-1]["init_args"][init_arg_name] = vj
+            i = j
+        # update at the end to preserve the order
+        for k, v in passed_configs.items():
+            config.extend(v)
+        if not config:
+            return clean_argv
+        return clean_argv + [argv_key, str(config)]
+

 class SaveConfigCallback(Callback):
     """Saves a LightningCLI config to the log_dir when training starts.
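To make the list rewrite concrete, here is a sketch that mirrors the documentation example and the tests added later in this PR:

    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY, LightningArgumentParser

    argv = [
        "any.py",
        "--trainer.callbacks=EarlyStopping",
        "--trainer.callbacks.patience=5",
        "--trainer.callbacks=LearningRateMonitor",
        "--trainer.callbacks.logging_interval=epoch",
    ]
    new_argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", argv)
    # the shorthand collapses into a single --trainer.callbacks value holding a list of
    # class_path/init_args dictionaries, one per callback
    print(new_argv)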
@@ -430,6 +505,7 @@ class LightningCLI:
     def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
         """Adds arguments from the core classes to the parser."""
         parser.add_lightning_class_args(self.trainer_class, "trainer")
+        parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
         trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
         parser.set_defaults(trainer_defaults)
         parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
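A minimal sketch of what the new ``set_choices`` call wires up, using only the parser APIs shown in this diff: the key ``trainer.callbacks`` is recorded as a ``List[object]`` argument whose shorthand names are resolved against ``CALLBACK_REGISTRY`` at ``parse_args`` time.

    from pytorch_lightning import Trainer
    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY, LightningArgumentParser

    parser = LightningArgumentParser()
    parser.add_lightning_class_args(Trainer, "trainer")
    parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
    # parser.parse_args() would now rewrite e.g. --trainer.callbacks=EarlyStopping into the
    # full class_path/init_args form before jsonargparse ever sees it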
@@ -34,6 +34,7 @@ from pytorch_lightning.plugins.environments import SLURMEnvironment
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.cli import (
+    CALLBACK_REGISTRY,
     instantiate_class,
     LightningArgumentParser,
     LightningCLI,
@@ -861,6 +862,11 @@ class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
     pass


+@CALLBACK_REGISTRY
+class CustomCallback(Callback):
+    pass
+
+
 def test_registries(tmpdir):
     assert "SGD" in OPTIMIZER_REGISTRY.names
     assert "RMSprop" in OPTIMIZER_REGISTRY.names
@@ -870,23 +876,41 @@ def test_registries(tmpdir):
     assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
     assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names

+    assert "EarlyStopping" in CALLBACK_REGISTRY.names
+    assert "CustomCallback" in CALLBACK_REGISTRY.names
+
     with pytest.raises(MisconfigurationException, match="is already present in the registry"):
         OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
+    OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)


-def test_registries_resolution():
+@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
+def test_registries_resolution(use_class_path_callbacks):
     """This test validates that registries are used when the simplified command line syntax is used."""
     cli_args = [
         "--optimizer",
         "Adam",
         "--optimizer.lr",
         "0.0001",
+        "--trainer.callbacks=LearningRateMonitor",
+        "--trainer.callbacks.logging_interval=epoch",
+        "--trainer.callbacks.log_momentum=True",
+        "--trainer.callbacks=ModelCheckpoint",
+        "--trainer.callbacks.monitor=loss",
         "--lr_scheduler",
         "StepLR",
         "--lr_scheduler.step_size=50",
     ]

+    extras = []
+    if use_class_path_callbacks:
+        callbacks = [
+            {"class_path": "pytorch_lightning.callbacks.Callback"},
+            {"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
+        ]
+        cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
+        extras = [Callback, Callback]
+
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = LightningCLI(BoringModel, run=False)
@@ -895,6 +919,80 @@ def test_registries_resolution():
     assert optimizers[0].param_groups[0]["lr"] == 0.0001
     assert lr_scheduler[0].step_size == 50

+    callback_types = [type(c) for c in cli.trainer.callbacks]
+    expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
+    assert all(t in callback_types for t in expected)
+
+
+def test_argv_transformation_noop():
+    base = ["any.py", "--trainer.max_epochs=1"]
+    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
+    assert argv == base
+
+
+def test_argv_transformation_single_callback():
+    base = ["any.py", "--trainer.max_epochs=1"]
+    input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
+    callbacks = [
+        {
+            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
+            "init_args": {"monitor": "val_loss"},
+        }
+    ]
+    expected = base + ["--trainer.callbacks", str(callbacks)]
+    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
+    assert argv == expected
+
+
+def test_argv_transformation_multiple_callbacks():
+    base = ["any.py", "--trainer.max_epochs=1"]
+    input = base + [
+        "--trainer.callbacks=ModelCheckpoint",
+        "--trainer.callbacks.monitor=val_loss",
+        "--trainer.callbacks=ModelCheckpoint",
+        "--trainer.callbacks.monitor=val_acc",
+    ]
+    callbacks = [
+        {
+            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
+            "init_args": {"monitor": "val_loss"},
+        },
+        {
+            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
+            "init_args": {"monitor": "val_acc"},
+        },
+    ]
+    expected = base + ["--trainer.callbacks", str(callbacks)]
+    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
+    assert argv == expected
+
+
+def test_argv_transformation_multiple_callbacks_with_config():
+    base = ["any.py", "--trainer.max_epochs=1"]
+    nested_key = "trainer.callbacks"
+    input = base + [
+        f"--{nested_key}=ModelCheckpoint",
+        f"--{nested_key}.monitor=val_loss",
+        f"--{nested_key}=ModelCheckpoint",
+        f"--{nested_key}.monitor=val_acc",
+        f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
+    ]
+    callbacks = [
+        {
+            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
+            "init_args": {"monitor": "val_loss"},
+        },
+        {
+            "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
+            "init_args": {"monitor": "val_acc"},
+        },
+        {"class_path": "pytorch_lightning.callbacks.Callback"},
+    ]
+    expected = base + ["--trainer.callbacks", str(callbacks)]
+    nested_key = "trainer.callbacks"
+    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
+    assert argv == expected
+

 @pytest.mark.parametrize(
     ["args", "expected", "nested_key", "registry"],