Improve `LightningCLI` documentation and tests (#7156)
* Added CLI unit tests for help, print_config and submodules; documented in the CLI docs the use of subclass help and print_config, submodules, and other minor improvements; increased the minimum jsonargparse version required for the newly documented features.
* Improvements to lightning_cli.rst
* Added a check for all trainer parameters in test_lightning_cli_help
* Increased the minimum jsonargparse version

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
d123aaa6a1
commit
78d45a1134
|
@ -33,6 +33,9 @@
|
|||
MyModelBaseClass = MyModel
|
||||
MyDataModuleBaseClass = MyDataModule
|
||||
|
||||
EncoderBaseClass = MyModel
|
||||
DecoderBaseClass = MyModel
|
||||
|
||||
mock_argv = mock.patch("sys.argv", ["any.py"])
|
||||
mock_argv.start()
|
||||
|
||||
|
@ -116,7 +119,7 @@ The start of a possible implementation of :class:`MyModel` including the recomme
|
|||
docstring could be the one below. Note that by using type hints and docstrings there is no need to duplicate this
|
||||
information to define its configurable arguments.
|
||||
|
||||
.. code-block:: python
|
||||
.. testcode::
|
||||
|
||||
class MyModel(LightningModule):
|
||||
|
||||
|
@ -131,7 +134,8 @@ information to define its configurable arguments.
|
|||
encoder_layers: Number of layers for the encoder
|
||||
decoder_layers: Number of layers for each decoder block
|
||||
"""
|
||||
...
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
With this model class, the help of the trainer tool would look as follows:
|
||||
|
||||
|
@ -258,7 +262,67 @@ A possible config file could be as follows:
|
|||
...
|
||||
|
||||
Only model classes that are a subclass of :code:`MyModelBaseClass` would be allowed, and similarly only subclasses of
|
||||
:code:`MyDataModuleBaseClass`.
|
||||
:code:`MyDataModuleBaseClass`. If as base classes :class:`~pytorch_lightning.core.lightning.LightningModule` and
|
||||
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` are given, then the tool would allow any lightning
|
||||
module and data module.
|
||||
|
||||
.. tip::
|
||||
|
||||
Note that with the subclass modes the :code:`--help` option does not show information for a specific subclass. To
|
||||
get help for a subclass the options :code:`--model.help` and :code:`--data.help` can be used, followed by the
|
||||
desired class path. Similarly :code:`--print_config` does not include the settings for a particular subclass. To
|
||||
include them the class path should be given before the :code:`--print_config` option. Examples for both help and
|
||||
print config are:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
$ python trainer.py --model.help mycode.mymodels.MyModel
|
||||
$ python trainer.py --model mycode.mymodels.MyModel --print_config
|
||||
|
||||
|
||||
Models with multiple submodules
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Many use cases require having several modules, each with its own configurable options. One possible way to handle this
|
||||
with LightningCLI is to implement a single module that receives each of the submodules as init parameters. Since the init
|
||||
parameters have as type a class, then in the configuration these would be specified with :code:`class_path` and
|
||||
:code:`init_args` entries. For instance a model could be implemented as:
|
||||
|
||||
.. testcode::
|
||||
|
||||
class MyMainModel(LightningModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderBaseClass,
|
||||
decoder: DecoderBaseClass
|
||||
):
|
||||
"""Example encoder-decoder submodules model
|
||||
|
||||
Args:
|
||||
encoder: Instance of a module for encoding
|
||||
decoder: Instance of a module for decoding
|
||||
"""
|
||||
super().__init__()
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
|
||||
If the CLI is implemented as :code:`LightningCLI(MyMainModel)` the configuration would be as follows:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
model:
|
||||
encoder:
|
||||
class_path: mycode.myencoders.MyEncoder
|
||||
init_args:
|
||||
...
|
||||
decoder:
|
||||
class_path: mycode.mydecoders.MyDecoder
|
||||
init_args:
|
||||
...
|
||||
|
||||
It is also possible to combine :code:`subclass_mode_model=True` and submodules, thereby having two levels of
|
||||
:code:`class_path`.
|
||||
|
||||
|
||||
Customizing LightningCLI
|
||||
|
@ -275,7 +339,7 @@ extended to customize different parts of the command line tool. The argument par
|
|||
adding arguments can be done using the :func:`add_argument` method. In contrast to argparse it has additional methods to
|
||||
add arguments, for example :func:`add_class_arguments` adds all arguments from the init of a class, though requiring
|
||||
parameters to have type hints. For more details about this please refer to the `respective documentation
|
||||
<https://omni-us.github.io/jsonargparse/#classes-methods-and-functions>`_.
|
||||
<https://jsonargparse.readthedocs.io/en/stable/#classes-methods-and-functions>`_.
|
||||
|
||||
The :class:`~pytorch_lightning.utilities.cli.LightningCLI` class has the
|
||||
:meth:`~pytorch_lightning.utilities.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
|
||||
|
|
|
@ -7,4 +7,4 @@ torchtext>=0.5
|
|||
# onnx>=1.7.0
|
||||
onnxruntime>=1.3.0
|
||||
hydra-core>=1.0
|
||||
jsonargparse[signatures]>=3.9.0
|
||||
jsonargparse[signatures]>=3.10.1
|
||||
|
|
|
@ -12,17 +12,20 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
from argparse import Namespace
|
||||
from contextlib import redirect_stdout
|
||||
from io import StringIO
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
||||
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
|
||||
|
@ -329,3 +332,98 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
|
|||
assert config['model'] == cli.config['model']
|
||||
assert config['data'] == cli.config['data']
|
||||
assert config['trainer'] == cli.config['trainer']
|
||||
|
||||
|
||||
def any_model_any_data_cli():
    """Run a ``LightningCLI`` in which both the model and the data module are selected by class path."""
    LightningCLI(LightningModule, LightningDataModule, subclass_mode_model=True, subclass_mode_data=True)
|
||||
|
||||
|
||||
def test_lightning_cli_help():
    """Check that ``--help`` lists the expected top-level options (including every ``Trainer.__init__``
    parameter) and that ``--data.help`` prints the init arguments of the requested subclass."""

    def run_and_capture(argv):
        # The CLI prints help and then calls sys.exit(), hence the SystemExit guard.
        captured = StringIO()
        with mock.patch('sys.argv', argv), redirect_stdout(captured), pytest.raises(SystemExit):
            any_model_any_data_cli()
        return captured.getvalue()

    help_text = run_and_capture(['any.py', '--help'])

    for option in ('--print_config', '--config', '--seed_everything', '--model.help', '--data.help'):
        assert option in help_text

    # Every Trainer init parameter (except self) must be exposed on the command line.
    for param in inspect.signature(Trainer.__init__).parameters:
        if param != 'self':
            assert f'--trainer.{param}' in help_text

    subclass_help = run_and_capture(['any.py', '--data.help=tests.helpers.BoringDataModule'])
    assert '--data.init_args.data_dir' in subclass_help
|
||||
|
||||
|
||||
def test_lightning_cli_print_config():
    """Ensure ``--print_config`` emits a YAML document reflecting the options given on the command line."""
    argv = [
        'any.py',
        '--seed_everything=1234',
        '--model=tests.helpers.BoringModel',
        '--data=tests.helpers.BoringDataModule',
        '--print_config',
    ]

    captured = StringIO()
    # --print_config writes the full config to stdout and exits, so SystemExit is expected.
    with mock.patch('sys.argv', argv), redirect_stdout(captured), pytest.raises(SystemExit):
        any_model_any_data_cli()

    printed = yaml.safe_load(captured.getvalue())
    assert printed['seed_everything'] == 1234
    assert printed['model']['class_path'] == 'tests.helpers.BoringModel'
    assert printed['data']['class_path'] == 'tests.helpers.BoringDataModule'
|
||||
|
||||
|
||||
def test_lightning_cli_submodules(tmpdir):
    """Verify that submodules declared as typed init parameters are instantiated from
    ``class_path`` entries given in a YAML config file."""

    class MainModule(BoringModel):

        def __init__(
            self,
            submodule1: LightningModule,
            submodule2: LightningModule,
            main_param: int = 1,
        ):
            super().__init__()
            self.submodule1 = submodule1
            self.submodule2 = submodule2

    config_yaml = (
        "model:\n"
        "  main_param: 2\n"
        "  submodule1:\n"
        "    class_path: tests.helpers.BoringModel\n"
        "  submodule2:\n"
        "    class_path: tests.helpers.BoringModel\n"
    )
    config_path = tmpdir / 'config.yaml'
    with open(config_path, 'w') as f:
        f.write(config_yaml)

    argv = [
        'any.py',
        f'--trainer.default_root_dir={tmpdir}',
        '--trainer.max_epochs=1',
        f'--config={str(config_path)}',
    ]
    with mock.patch('sys.argv', argv):
        cli = LightningCLI(MainModule)

    # The instantiated config must reflect the file's values and hold the actual submodule instances.
    model_init = cli.config_init['model']
    assert model_init['main_param'] == 2
    assert cli.model.submodule1 == model_init['submodule1']
    assert cli.model.submodule2 == model_init['submodule2']
    assert isinstance(model_init['submodule1'], BoringModel)
    assert isinstance(model_init['submodule2'], BoringModel)
|
||||
|
|
Loading…
Reference in New Issue