From 0004216f2f90c39afdb938eff41ec73458f858f3 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Wed, 16 Jun 2021 02:03:37 +0200 Subject: [PATCH] Easier configurability of callbacks that should always be present in LightningCLI (#7964) Co-authored-by: Ethan Harris Co-authored-by: Carlos Mocholi --- CHANGELOG.md | 3 ++ docs/source/common/lightning_cli.rst | 48 +++++++++++++++++++++++++--- pytorch_lightning/utilities/cli.py | 17 ++++++---- setup.cfg | 2 ++ tests/utilities/test_cli.py | 21 ++++++++++++ 5 files changed, 80 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ddaf4288a0..f5ca6e5b9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added LightningCLI support for argument links applied on instantiation ([#7895](https://github.com/PyTorchLightning/pytorch-lightning/pull/7895)) +- Added LightningCLI support for configurable callbacks that should always be present ([#7964](https://github.com/PyTorchLightning/pytorch-lightning/pull/7964)) + + - Added DeepSpeed Infinity Support, and updated to DeepSpeed 0.4.0 ([#7234](https://github.com/PyTorchLightning/pytorch-lightning/pull/7234)) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 949f85f58e..1fd3132443 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -92,6 +92,8 @@ practice to create a configuration file and provide this to the tool. A way to d nano config.yaml # Run training using created configuration python trainer.py --config config.yaml + # The config JSON can also be passed directly + python trainer.py --config '{trainer: {fast_dev_run: True}}' The instantiation of the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class takes care of parsing command line and config file options, instantiating the classes, setting up a callback to save the config in the log directory and @@ -376,6 +378,47 @@ Note that the config object :code:`self.config` is a dictionary whose keys are g has the same structure as the yaml format described previously. This means for instance that the parameters used for instantiating the trainer class can be found in :code:`self.config['trainer']`. +.. tip:: + + Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other + methods that can be extended to customize a CLI. + + +Configurable callbacks +~~~~~~~~~~~~~~~~~~~~~~ + +As explained previously, any callback can be added by including it in the config via :code:`class_path` and +:code:`init_args` entries. However, there are other cases in which a callback should always be present and be +configurable. This can be implemented as follows: + +.. testcode:: + + from pytorch_lightning.callbacks import EarlyStopping + from pytorch_lightning.utilities.cli import LightningCLI + + class MyLightningCLI(LightningCLI): + + def add_arguments_to_parser(self, parser): + parser.add_lightning_class_args(EarlyStopping, 'my_early_stopping') + parser.set_defaults({'my_early_stopping.patience': 5}) + + cli = MyLightningCLI(MyModel) + +To change the configuration of the :code:`EarlyStopping` in the config it would be: + +.. code-block:: yaml + + model: + ... + trainer: + ... + my_early_stopping: + patience: 5 + + +Argument linking +~~~~~~~~~~~~~~~~ + Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is that the model and data module depend on a common parameter. For example in some cases both classes require to know the :code:`batch_size`. It is a burden and error prone giving the same value twice in a config file. To avoid this the @@ -427,8 +470,3 @@ Instantiation links are used to automatically determine the order of instantiati The linking of arguments can be used for more complex cases. For example to derive a value via a function that takes multiple settings as input. For more details have a look at the API of `link_arguments `_. - -.. tip:: - - Have a look at the :class:`~pytorch_lightning.utilities.cli.LightningCLI` class API reference to learn about other - methods that can be extended to customize a CLI. diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index aed9de3355..1f1788393b 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -13,7 +13,7 @@ # limitations under the License. import os from argparse import Namespace -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type, Union from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.datamodule import LightningDataModule @@ -33,7 +33,7 @@ else: class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning""" - def __init__(self, *args, parse_as_dict: bool = True, **kwargs) -> None: + def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input For full details of accepted arguments see `ArgumentParser.__init__ @@ -48,22 +48,25 @@ class LightningArgumentParser(ArgumentParser): self.add_argument( '--config', action=ActionConfigFile, help='Path to a configuration file in json or yaml format.' ) + self.callback_keys: List[str] = [] def add_lightning_class_args( self, - lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule]], + lightning_class: Union[Type[Trainer], Type[LightningModule], Type[LightningDataModule], Type[Callback]], nested_key: str, subclass_mode: bool = False - ) -> None: + ) -> List[str]: """ Adds arguments from a lightning class to a nested key of the parser Args: - lightning_class: Any subclass of {Trainer,LightningModule,LightningDataModule}. + lightning_class: Any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. nested_key: Name of the nested namespace to store arguments. subclass_mode: Whether allow any subclass of the given class. """ - assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule)) + assert issubclass(lightning_class, (Trainer, LightningModule, LightningDataModule, Callback)) + if issubclass(lightning_class, Callback): + self.callback_keys.append(nested_key) if subclass_mode: return self.add_subclass_arguments(lightning_class, nested_key, required=True) return self.add_class_arguments( @@ -235,6 +238,8 @@ class LightningCLI: """Instantiates the trainer using self.config_init['trainer']""" if self.config_init['trainer'].get('callbacks') is None: self.config_init['trainer']['callbacks'] = [] + callbacks = [self.config_init[c] for c in self.parser.callback_keys] + self.config_init['trainer']['callbacks'].extend(callbacks) if 'callbacks' in self.trainer_defaults: if isinstance(self.trainer_defaults['callbacks'], list): self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks']) diff --git a/setup.cfg b/setup.cfg index b90c0663c5..594ebcb594 100644 --- a/setup.cfg +++ b/setup.cfg @@ -182,6 +182,8 @@ ignore_errors = True # todo: add proper typing to this module... [mypy-pytorch_lightning.utilities.*] ignore_errors = True +[mypy-pytorch_lightning.utilities.cli] +ignore_errors = False # todo: add proper typing to this module... [mypy-pl_examples.*] diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 40ddbf9b16..458726662d 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -289,6 +289,27 @@ def test_lightning_cli_args_callbacks(tmpdir): assert cli.trainer.ran_asserts +def test_lightning_cli_configurable_callbacks(tmpdir): + + class MyLightningCLI(LightningCLI): + + def add_arguments_to_parser(self, parser): + parser.add_lightning_class_args(LearningRateMonitor, 'learning_rate_monitor') + + cli_args = [ + f'--trainer.default_root_dir={tmpdir}', + '--trainer.max_epochs=1', + '--learning_rate_monitor.logging_interval=epoch', + ] + + with mock.patch('sys.argv', ['any.py'] + cli_args): + cli = MyLightningCLI(BoringModel) + + callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)] + assert len(callback) == 1 + assert callback[0].logging_interval == 'epoch' + + def test_lightning_cli_args_cluster_environments(tmpdir): plugins = [dict(class_path='pytorch_lightning.plugins.environments.SLURMEnvironment')]