Improve `LightningCLI` documentation and tests (#7156)

* - Added cli unit tests for help, print_config and submodules.
- Added to cli documentation use of subclass help and print_config, submodules and other minor improvements.
- Increased minimum jsonargparse version required for new documented features.

* Improvements to lightning_cli.rst

* Add check for all trainer parameters in test_lightning_cli_help

* Increased minimum jsonargparse version

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Mauricio Villegas 2021-04-28 10:34:32 +02:00 committed by GitHub
parent d123aaa6a1
commit 78d45a1134
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 168 additions and 6 deletions

View File

@ -33,6 +33,9 @@
MyModelBaseClass = MyModel
MyDataModuleBaseClass = MyDataModule
EncoderBaseClass = MyModel
DecoderBaseClass = MyModel
mock_argv = mock.patch("sys.argv", ["any.py"])
mock_argv.start()
@ -116,7 +119,7 @@ The start of a possible implementation of :class:`MyModel` including the recomme
docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this
information to define its configurable arguments.
.. code-block:: python
.. testcode::
class MyModel(LightningModule):
@ -131,7 +134,8 @@ information to define its configurable arguments.
encoder_layers: Number of layers for the encoder
decoder_layers: Number of layers for each decoder block
"""
...
super().__init__()
self.save_hyperparameters()
With this model class, the help of the trainer tool would look as follows:
@ -258,7 +262,67 @@ A possible config file could be as follows:
...
Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of
:code:`MyDataModuleBaseClass`.
:code:`MyDataModuleBaseClass`. If as base classes :class:`~pytorch_lightning.core.lightning.LightningModule` and
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` are given, then the tool would allow any lightning
module and data module.
.. tip::
Note that with the subclass modes the :code:`--help` option does not show information for a specific subclass. To
get help for a subclass the options :code:`--model.help` and :code:`--data.help` can be used, followed by the
desired class path. Similarly :code:`--print_config` does not include the settings for a particular subclass. To
include them the class path should be given before the :code:`--print_config` option. Examples for both help and
print config are:
.. code-block:: bash
$ python trainer.py --model.help mycode.mymodels.MyModel
$ python trainer.py --model mycode.mymodels.MyModel --print_config
Models with multiple submodules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Many use cases require to have several modules each with its own configurable options. One possible way to handle this
with LightningCLI is to implement a single module having as init parameters each of the submodules. Since the init
parameters have as type a class, then in the configuration these would be specified with :code:`class_path` and
:code:`init_args` entries. For instance a model could be implemented as:
.. testcode::
class MyMainModel(LightningModule):
def __init__(
self,
encoder: EncoderBaseClass,
decoder: DecoderBaseClass
):
"""Example encoder-decoder submodules model
Args:
encoder: Instance of a module for encoding
decoder: Instance of a module for decoding
"""
super().__init__()
self.encoder = encoder
self.decoder = decoder
If the CLI is implemented as :code:`LightningCLI(MyMainModel)` the configuration would be as follows:
.. code-block:: yaml
model:
encoder:
class_path: mycode.myencoders.MyEncoder
init_args:
...
decoder:
class_path: mycode.mydecoders.MyDecoder
init_args:
...
It is also possible to combine :code:`subclass_mode_model=True` and submodules, thereby having two levels of
:code:`class_path`.
Customizing LightningCLI
@ -275,7 +339,7 @@ extended to customize different parts of the command line tool. The argument par
adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods to
add arguments, for example :func:`add_class_arguments` adds all arguments from the init of a class, though requiring
parameters to have type hints. For more details about this please refer to the `respective documentation
<https://omni-us.github.io/jsonargparse/#classes-methods-and-functions>`_.
<https://jsonargparse.readthedocs.io/en/stable/#classes-methods-and-functions>`_.
The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include

View File

@ -7,4 +7,4 @@ torchtext>=0.5
# onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.9.0
jsonargparse[signatures]>=3.10.1

View File

@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import pickle
import sys
from argparse import Namespace
from contextlib import redirect_stdout
from io import StringIO
from unittest import mock
import pytest
import yaml
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
@ -329,3 +332,98 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
assert config['model'] == cli.config['model']
assert config['data'] == cli.config['data']
assert config['trainer'] == cli.config['trainer']
def any_model_any_data_cli():
LightningCLI(
LightningModule,
LightningDataModule,
subclass_mode_model=True,
subclass_mode_data=True,
)
def test_lightning_cli_help():
cli_args = ['any.py', '--help']
out = StringIO()
with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit):
any_model_any_data_cli()
assert '--print_config' in out.getvalue()
assert '--config' in out.getvalue()
assert '--seed_everything' in out.getvalue()
assert '--model.help' in out.getvalue()
assert '--data.help' in out.getvalue()
skip_params = {'self'}
for param in inspect.signature(Trainer.__init__).parameters.keys():
if param not in skip_params:
assert f'--trainer.{param}' in out.getvalue()
cli_args = ['any.py', '--data.help=tests.helpers.BoringDataModule']
out = StringIO()
with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit):
any_model_any_data_cli()
assert '--data.init_args.data_dir' in out.getvalue()
def test_lightning_cli_print_config():
cli_args = [
'any.py',
'--seed_everything=1234',
'--model=tests.helpers.BoringModel',
'--data=tests.helpers.BoringDataModule',
'--print_config',
]
out = StringIO()
with mock.patch('sys.argv', cli_args), redirect_stdout(out), pytest.raises(SystemExit):
any_model_any_data_cli()
outval = yaml.safe_load(out.getvalue())
assert outval['seed_everything'] == 1234
assert outval['model']['class_path'] == 'tests.helpers.BoringModel'
assert outval['data']['class_path'] == 'tests.helpers.BoringDataModule'
def test_lightning_cli_submodules(tmpdir):
class MainModule(BoringModel):
def __init__(
self,
submodule1: LightningModule,
submodule2: LightningModule,
main_param: int = 1,
):
super().__init__()
self.submodule1 = submodule1
self.submodule2 = submodule2
config = """model:
main_param: 2
submodule1:
class_path: tests.helpers.BoringModel
submodule2:
class_path: tests.helpers.BoringModel
"""
config_path = tmpdir / 'config.yaml'
with open(config_path, 'w') as f:
f.write(config)
cli_args = [
f'--trainer.default_root_dir={tmpdir}',
'--trainer.max_epochs=1',
f'--config={str(config_path)}',
]
with mock.patch('sys.argv', ['any.py'] + cli_args):
cli = LightningCLI(MainModule)
assert cli.config_init['model']['main_param'] == 2
assert cli.model.submodule1 == cli.config_init['model']['submodule1']
assert cli.model.submodule2 == cli.config_init['model']['submodule2']
assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
assert isinstance(cli.config_init['model']['submodule2'], BoringModel)