Easier configurability of callbacks that should always be present in LightningCLI (#7964)
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk> Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
This commit is contained in:
parent
78a14a3f56
commit
0004216f2f
|
@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added LightningCLI support for argument links applied on instantiation ([#7895](https://github.com/PyTorchLightning/pytorch-lightning/pull/7895))
|
||||
|
||||
|
||||
- Added LightningCLI support for configurable callbacks that should always be present ([#7964](https://github.com/PyTorchLightning/pytorch-lightning/pull/7964))
|
||||
|
||||
|
||||
- Added DeepSpeed Infinity Support, and updated to DeepSpeed 0.4.0 ([#7234](https://github.com/PyTorchLightning/pytorch-lightning/pull/7234))
|
||||
|
||||
|
||||
|
|
|
@ -92,6 +92,8 @@ practice to create a configuration file and provide this to the tool. A way to d
|
|||
nano config.yaml
|
||||
# Run training using created configuration
|
||||
python trainer.py --config config.yaml
|
||||
# The config JSON can also be passed directly
|
||||
python trainer.py --config '{trainer: {fast_dev_run: True}}'
|
||||
|
||||
The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line
|
||||
and config file options, instantiating the classes, setting up a callback to save the config in the log directory and
|
||||
|
@ -376,6 +378,47 @@ Note that the config object :code:`self.config` is a dictionary whose keys are g
|
|||
has the same structure as the yaml format described previously. This means for instance that the parameters used for
|
||||
instantiating the trainer class can be found in :code:`self.config['trainer']`.
|
||||
|
||||
.. tip::
|
||||
|
||||
Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
|
||||
methods that can be extended to customize a CLI.
|
||||
|
||||
|
||||
Configurable callbacks
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
As explained previously, any callback can be added by including it in the config via :code:`class_path` and
|
||||
:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
|
||||
configurable. This can be implemented as follows:
|
||||
|
||||
.. testcode::
|
||||
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.utilities.cli import LightningCLI
|
||||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_lightning_class_args(EarlyStopping, 'my_early_stopping')
|
||||
parser.set_defaults({'my_early_stopping.patience': 5})
|
||||
|
||||
cli = MyLightningCLI(MyModel)
|
||||
|
||||
To change the configuration of the :code:`EarlyStopping` in the config it would be:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
model:
|
||||
...
|
||||
trainer:
|
||||
...
|
||||
my_early_stopping:
|
||||
patience: 5
|
||||
|
||||
|
||||
Argument linking
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is that the
|
||||
model and data module depend on a common parameter. For example in some cases both classes require to know the
|
||||
:code:`batch_size`. It is a burden and error prone giving the same value twice in a config file. To avoid this the
|
||||
|
@ -427,8 +470,3 @@ Instantiation links are used to automatically determine the order of instantiati
|
|||
The linking of arguments can be used for more complex cases. For example to derive a value via a function that takes
|
||||
multiple settings as input. For more details have a look at the API of `link_arguments
|
||||
<https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.link_arguments>`_.
|
||||
|
||||
.. tip::
|
||||
|
||||
Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other
|
||||
methods that can be extended to customize a CLI.
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from argparse import Namespace
|
||||
from typing import Any, Dict, Optional, Type, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
|
@ -33,7 +33,7 @@ else:
|
|||
class LightningArgumentParser(ArgumentParser):
|
||||
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning"""
|
||||
|
||||
def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None:
|
||||
def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
|
||||
"""Initialize argument parser that supports configuration file input
|
||||
|
||||
For full details of accepted arguments see `ArgumentParser.__init__
|
||||
|
@ -48,22 +48,25 @@ class LightningArgumentParser(ArgumentParser):
|
|||
self.add_argument(
|
||||
'--config', action=ActionConfigFile, help='Path to a configuration file in json or yaml format.'
|
||||
)
|
||||
self.callback_keys: List[str] = []
|
||||
|
||||
def add_lightning_class_args(
|
||||
self,
|
||||
lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]],
|
||||
lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule], Type[Callback]],
|
||||
nested_key: str,
|
||||
subclass_mode: bool = False
|
||||
) -> None:
|
||||
) -> List[str]:
|
||||
"""
|
||||
Adds arguments from a lightning class to a nested key of the parser
|
||||
|
||||
Args:
|
||||
lightning_class: Any subclass of {Trainer,LightningModule,LightningDataModule}.
|
||||
lightning_class: Any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
|
||||
nested_key: Name of the nested namespace to store arguments.
|
||||
subclass_mode: Whether allow any subclass of the given class.
|
||||
"""
|
||||
assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule))
|
||||
assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule, Callback))
|
||||
if issubclass(lightning_class, Callback):
|
||||
self.callback_keys.append(nested_key)
|
||||
if subclass_mode:
|
||||
return self.add_subclass_arguments(lightning_class, nested_key, required=True)
|
||||
return self.add_class_arguments(
|
||||
|
@ -235,6 +238,8 @@ class LightningCLI:
|
|||
"""Instantiates the trainer using self.config_init['trainer']"""
|
||||
if self.config_init['trainer'].get('callbacks') is None:
|
||||
self.config_init['trainer']['callbacks'] = []
|
||||
callbacks = [self.config_init[c] for c in self.parser.callback_keys]
|
||||
self.config_init['trainer']['callbacks'].extend(callbacks)
|
||||
if 'callbacks' in self.trainer_defaults:
|
||||
if isinstance(self.trainer_defaults['callbacks'], list):
|
||||
self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks'])
|
||||
|
|
|
@ -182,6 +182,8 @@ ignore_errors = True
|
|||
# todo: add proper typing to this module...
|
||||
[mypy-pytorch_lightning.utilities.*]
|
||||
ignore_errors = True
|
||||
[mypy-pytorch_lightning.utilities.cli]
|
||||
ignore_errors = False
|
||||
|
||||
# todo: add proper typing to this module...
|
||||
[mypy-pl_examples.*]
|
||||
|
|
|
@ -289,6 +289,27 @@ def test_lightning_cli_args_callbacks(tmpdir):
|
|||
assert cli.trainer.ran_asserts
|
||||
|
||||
|
||||
def test_lightning_cli_configurable_callbacks(tmpdir):
|
||||
|
||||
class MyLightningCLI(LightningCLI):
|
||||
|
||||
def add_arguments_to_parser(self, parser):
|
||||
parser.add_lightning_class_args(LearningRateMonitor, 'learning_rate_monitor')
|
||||
|
||||
cli_args = [
|
||||
f'--trainer.default_root_dir={tmpdir}',
|
||||
'--trainer.max_epochs=1',
|
||||
'--learning_rate_monitor.logging_interval=epoch',
|
||||
]
|
||||
|
||||
with mock.patch('sys.argv', ['any.py'] + cli_args):
|
||||
cli = MyLightningCLI(BoringModel)
|
||||
|
||||
callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)]
|
||||
assert len(callback) == 1
|
||||
assert callback[0].logging_interval == 'epoch'
|
||||
|
||||
|
||||
def test_lightning_cli_args_cluster_environments(tmpdir):
|
||||
plugins = [dict(class_path='pytorch_lightning.plugins.environments.SLURMEnvironment')]
|
||||
|
||||
|
|
Loading…
Reference in New Issue