diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index a5cdf9f93e..9d44d65db7 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -33,6 +33,9 @@ MyModelBaseClass = MyModel MyDataModuleBaseClass = MyDataModule + EncoderBaseClass = MyModel + DecoderBaseClass = MyModel + mock_argv = mock.patch("sys.argv", ["any.py"]) mock_argv.start() @@ -116,7 +119,7 @@ The start of a possible implementation of :class:`MyModel` including the recomme docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this information to define its configurable arguments. -.. code-block:: python +.. testcode:: class MyModel(LightningModule): @@ -131,7 +134,8 @@ information to define its configurable arguments. encoder_layers: Number of layers for the encoder decoder_layers: Number of layers for each decoder block """ - ... + super().__init__() + self.save_hyperparameters() With this model class, the help of the trainer tool would look as follows: @@ -258,7 +262,67 @@ A possible config file could be as follows: ... Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of -:code:`MyDataModuleBaseClass`. +:code:`MyDataModuleBaseClass`. If as base classes :class:`~pytorch_lightning.core.lightning.LightningModule` and +:class:`~pytorch_lightning.core.datamodule.LightningDataModule` are given, then the tool would allow any lightning +module and data module. + +.. tip:: + + Note that with the subclass modes the :code:`--help` option does not show information for a specific subclass. To + get help for a subclass the options :code:`--model.help` and :code:`--data.help` can be used, followed by the + desired class path. Similarly :code:`--print_config` does not include the settings for a particular subclass. To + include them the class path should be given before the :code:`--print_config` option. Examples for both help and + print config are: + + .. code-block:: bash + + $ python trainer.py --model.help mycode.mymodels.MyModel + $ python trainer.py --model mycode.mymodels.MyModel --print_config + + +Models with multiple submodules +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Many use cases require to have several modules each with its own configurable options. One possible way to handle this +with LightningCLI is to implement a single module having as init parameters each of the submodules. Since the init +parameters have as type a class, then in the configuration these would be specified with :code:`class_path` and +:code:`init_args` entries. For instance a model could be implemented as: + +.. testcode:: + + class MyMainModel(LightningModule): + + def __init__( + self, + encoder: EncoderBaseClass, + decoder: DecoderBaseClass + ): + """Example encoder-decoder submodules model + + Args: + encoder: Instance of a module for encoding + decoder: Instance of a module for decoding + """ + super().__init__() + self.encoder = encoder + self.decoder = decoder + +If the CLI is implemented as :code:`LightningCLI(MyMainModel)` the configuration would be as follows: + +.. code-block:: yaml + + model: + encoder: + class_path: mycode.myencoders.MyEncoder + init_args: + ... + decoder: + class_path: mycode.mydecoders.MyDecoder + init_args: + ... + +It is also possible to combine :code:`subclass_mode_model=True` and submodules, thereby having two levels of +:code:`class_path`. Customizing LightningCLI @@ -275,7 +339,7 @@ extended to customize different parts of the command line tool. The argument par adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods to add arguments, for example :func:`add_class_arguments` adds all arguments from the init of a class, though requiring parameters to have type hints. For more details about this please refer to the `respective documentation -`_. +`_. The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the :meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include diff --git a/requirements/extra.txt b/requirements/extra.txt index e719ee3f30..98c4948125 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -jsonargparse[signatures]>=3.9.0 +jsonargparse[signatures]>=3.10.1 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index e6c067266c..057b73ef38 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import json import os import pickle import sys from argparse import Namespace +from contextlib import redirect_stdout +from io import StringIO from unittest import mock import pytest import yaml -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback @@ -329,3 +332,98 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir): assert config['model'] == cli.config['model'] assert config['data'] == cli.config['data'] assert config['trainer'] == cli.config['trainer'] + + +def any_model_any_data_cli(): + LightningCLI( + LightningModule, + LightningDataModule, + subclass_mode_model=True, + subclass_mode_data=True, + ) + + +def test_lightning_cli_help(): + + cli_args = ['any.py', '--help'] + out = StringIO() + with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): + any_model_any_data_cli() + + assert '--print_config' in out.getvalue() + assert '--config' in out.getvalue() + assert '--seed_everything' in out.getvalue() + assert '--model.help' in out.getvalue() + assert '--data.help' in out.getvalue() + + skip_params = {'self'} + for param in inspect.signature(Trainer.__init__).parameters.keys(): + if param not in skip_params: + assert f'--trainer.{param}' in out.getvalue() + + cli_args = ['any.py', '--data.help=tests.helpers.BoringDataModule'] + out = StringIO() + with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): + any_model_any_data_cli() + + assert '--data.init_args.data_dir' in out.getvalue() + + +def test_lightning_cli_print_config(): + + cli_args = [ + 'any.py', + '--seed_everything=1234', + '--model=tests.helpers.BoringModel', + '--data=tests.helpers.BoringDataModule', + '--print_config', + ] + + out = StringIO() + with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit): + any_model_any_data_cli() + + outval = yaml.safe_load(out.getvalue()) + assert outval['seed_everything'] == 1234 + assert outval['model']['class_path'] == 'tests.helpers.BoringModel' + assert outval['data']['class_path'] == 'tests.helpers.BoringDataModule' + + +def test_lightning_cli_submodules(tmpdir): + + class MainModule(BoringModel): + def __init__( + self, + submodule1: LightningModule, + submodule2: LightningModule, + main_param: int = 1, + ): + super().__init__() + self.submodule1 = submodule1 + self.submodule2 = submodule2 + + config = """model: + main_param: 2 + submodule1: + class_path: tests.helpers.BoringModel + submodule2: + class_path: tests.helpers.BoringModel + """ + config_path = tmpdir / 'config.yaml' + with open(config_path, 'w') as f: + f.write(config) + + cli_args = [ + f'--trainer.default_root_dir={tmpdir}', + '--trainer.max_epochs=1', + f'--config={str(config_path)}', + ] + + with mock.patch('sys.argv', ['any.py'] + cli_args): + cli = LightningCLI(MainModule) + + assert cli.config_init['model']['main_param'] == 2 + assert cli.model.submodule1 == cli.config_init['model']['submodule1'] + assert cli.model.submodule2 == cli.config_init['model']['submodule2'] + assert isinstance(cli.config_init['model']['submodule1'], BoringModel) + assert isinstance(cli.config_init['model']['submodule2'], BoringModel)