Easier configurability of callbacks that should always be present in LightningCLI (#7964)

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Mauricio Villegas 2021-06-16 02:03:37 +02:00 committed by GitHub
parent 78a14a3f56
commit 0004216f2f
5 changed files with 80 additions and 11 deletions

CHANGELOG.md

@@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added LightningCLI support for argument links applied on instantiation ([#7895](https://github.com/PyTorchLightning/pytorch-lightning/pull/7895))
+- Added LightningCLI support for configurable callbacks that should always be present ([#7964](https://github.com/PyTorchLightning/pytorch-lightning/pull/7964))
 - Added DeepSpeed Infinity Support, and updated to DeepSpeed 0.4.0 ([#7234](https://github.com/PyTorchLightning/pytorch-lightning/pull/7234))

docs/source/common/lightning_cli.rst

@@ -92,6 +92,8 @@ practice to create a configuration file and provide this to the tool. A way to d
     nano config.yaml
     # Run training using created configuration
     python trainer.py --config config.yaml
+    # The config JSON can also be passed directly
+    python trainer.py --config '{trainer: {fast_dev_run: True}}'
 
 The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line
 and config file options, instantiating the classes, setting up a callback to save the config in the log directory and
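
For reference, the :code:`trainer.py` invoked in the commands above can be as small as the following sketch, where :code:`MyModel` and its module are stand-ins for an actual :class:`LightningModule`:

    # trainer.py - minimal CLI entry point (sketch)
    from pytorch_lightning.utilities.cli import LightningCLI
    from my_models import MyModel  # hypothetical module defining a LightningModule

    cli = LightningCLI(MyModel)

Run with no arguments it trains with default settings, and every init parameter of the model and of the trainer automatically becomes a command line and config file option.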
@@ -376,6 +378,47 @@ Note that the config object :code:`self.config` is a dictionary whose keys are g
 has the same structure as the yaml format described previously. This means for instance that the parameters used for
 instantiating the trainer class can be found in :code:`self.config['trainer']`.
 
+.. tip::
+
+    Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
+    methods that can be extended to customize a CLI.
+
+
+Configurable callbacks
+~~~~~~~~~~~~~~~~~~~~~~
+
+As explained previously, any callback can be added by including it in the config via :code:`class_path` and
+:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
+configurable. This can be implemented as follows:
+
+.. testcode::
+
+    from pytorch_lightning.callbacks import EarlyStopping
+    from pytorch_lightning.utilities.cli import LightningCLI
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_lightning_class_args(EarlyStopping, 'my_early_stopping')
+            parser.set_defaults({'my_early_stopping.patience': 5})
+
+    cli = MyLightningCLI(MyModel)
+
+To change the configuration of the :code:`EarlyStopping` callback, the config file would include:
+
+.. code-block:: yaml
+
+    model:
+        ...
+    trainer:
+        ...
+    my_early_stopping:
+        patience: 5
 
 Argument linking
 ~~~~~~~~~~~~~~~~
 
 Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is when the
 model and data module depend on a common parameter. For example, in some cases both classes need to know the
 :code:`batch_size`. Giving the same value twice in a config file is a burden and error prone. To avoid this the
@@ -427,8 +470,3 @@ Instantiation links are used to automatically determine the order of instantiati
 The linking of arguments can be used for more complex cases. For example to derive a value via a function that takes
 multiple settings as input. For more details have a look at the API of `link_arguments
 <https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.link_arguments>`_.
-
-.. tip::
-
-    Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
-    methods that can be extended to customize a CLI.
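
To make the linking concrete, here is a minimal sketch of such a CLI subclass, assuming :code:`MyModel` and :code:`MyDataModule` (both stand-in names) each accept a :code:`batch_size` init argument:

    from pytorch_lightning.utilities.cli import LightningCLI
    from my_project import MyDataModule, MyModel  # hypothetical imports

    class MyLightningCLI(LightningCLI):

        def add_arguments_to_parser(self, parser):
            # give batch_size once under data.* and have it forwarded to the model
            parser.link_arguments('data.batch_size', 'model.batch_size')

    cli = MyLightningCLI(MyModel, MyDataModule)

With this link, :code:`--data.batch_size=32` on the command line (or the corresponding entry in a config file) also sets the model's :code:`batch_size`, so the value is given only once.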

pytorch_lightning/utilities/cli.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from argparse import Namespace
-from typing import Any, Dict, Optional, Type, Union
+from typing import Any, Dict, List, Optional, Type, Union
 
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.datamodule import LightningDataModule
@@ -33,7 +33,7 @@ else:
 
 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning"""
 
-    def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None:
+    def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
         """Initialize argument parser that supports configuration file input
 
         For full details of accepted arguments see `ArgumentParser.__init__
@@ -48,22 +48,25 @@ class LightningArgumentParser(ArgumentParser):
         self.add_argument(
             '--config', action=ActionConfigFile, help='Path to a configuration file in json or yaml format.'
         )
+        self.callback_keys: List[str] = []
 
     def add_lightning_class_args(
         self,
-        lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]],
+        lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule], Type[Callback]],
         nested_key: str,
         subclass_mode: bool = False
-    ) -> None:
+    ) -> List[str]:
         """
         Adds arguments from a lightning class to a nested key of the parser
 
         Args:
-            lightning_class: Any subclass of {Trainer,LightningModule,LightningDataModule}.
+            lightning_class: Any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
             nested_key: Name of the nested namespace to store arguments.
             subclass_mode: Whether to allow any subclass of the given class.
         """
-        assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule))
+        assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule, Callback))
+        if issubclass(lightning_class, Callback):
+            self.callback_keys.append(nested_key)
         if subclass_mode:
             return self.add_subclass_arguments(lightning_class, nested_key, required=True)
         return self.add_class_arguments(
@@ -235,6 +238,8 @@ class LightningCLI:
         """Instantiates the trainer using self.config_init['trainer']"""
         if self.config_init['trainer'].get('callbacks') is None:
             self.config_init['trainer']['callbacks'] = []
+        callbacks = [self.config_init[c] for c in self.parser.callback_keys]
+        self.config_init['trainer']['callbacks'].extend(callbacks)
         if 'callbacks' in self.trainer_defaults:
             if isinstance(self.trainer_defaults['callbacks'], list):
                 self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks'])

setup.cfg

@@ -182,6 +182,8 @@ ignore_errors = True
 # todo: add proper typing to this module...
 [mypy-pytorch_lightning.utilities.*]
 ignore_errors = True
 
+[mypy-pytorch_lightning.utilities.cli]
+ignore_errors = False
+
 # todo: add proper typing to this module...
 [mypy-pl_examples.*]

tests/utilities/test_cli.py

@@ -289,6 +289,27 @@ def test_lightning_cli_args_callbacks(tmpdir):
     assert cli.trainer.ran_asserts
 
 
+def test_lightning_cli_configurable_callbacks(tmpdir):
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_lightning_class_args(LearningRateMonitor, 'learning_rate_monitor')
+
+    cli_args = [
+        f'--trainer.default_root_dir={tmpdir}',
+        '--trainer.max_epochs=1',
+        '--learning_rate_monitor.logging_interval=epoch',
+    ]
+
+    with mock.patch('sys.argv', ['any.py'] + cli_args):
+        cli = MyLightningCLI(BoringModel)
+
+    callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)]
+    assert len(callback) == 1
+    assert callback[0].logging_interval == 'epoch'
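
The same settings could equally come from a config file instead of command line arguments. A sketch of the equivalent YAML for the keys used in this test (:code:`default_root_dir` omitted since the test fills it from :code:`tmpdir`):

    trainer:
      max_epochs: 1
    learning_rate_monitor:
      logging_interval: epoch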
 
 
 def test_lightning_cli_args_cluster_environments(tmpdir):
     plugins = [dict(class_path='pytorch_lightning.plugins.environments.SLURMEnvironment')]