:orphan:

.. testsetup:: *
    :skipif: not _JSONARGPARSE_AVAILABLE

    import torch
    from unittest import mock
    from typing import List
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule, LightningDataModule, Trainer, Callback


    class NoFitTrainer(Trainer):
        def fit(self, *_, **__):
            pass


    class LightningCLI(pl.cli.LightningCLI):
        def __init__(self, *args, trainer_class=NoFitTrainer, run=False, **kwargs):
            super().__init__(*args, trainer_class=trainer_class, run=run, **kwargs)


    class MyModel(LightningModule):
        def __init__(
            self,
            encoder_layers: int = 12,
            decoder_layers: List[int] = [2, 4],
            batch_size: int = 8,
        ):
            pass


    class MyClassModel(LightningModule):
        def __init__(self, num_classes: int):
            pass


    class MyDataModule(LightningDataModule):
        def __init__(self, batch_size: int = 8):
            self.num_classes = 5


    def send_email(address, message):
        pass


    mock_argv = mock.patch("sys.argv", ["any.py"])
    mock_argv.start()

.. testcleanup:: *

    mock_argv.stop()

#######################################
Eliminate config boilerplate (Advanced)
#######################################
**Audience:** Users who already understand the LightningCLI and want to customize it.

----

**************************
Customize the LightningCLI
**************************

The init parameters of the :class:`~pytorch_lightning.cli.LightningCLI` class can be used to customize some
things, namely: the description of the tool, enabling parsing of environment variables, and additional arguments to
instantiate the trainer and configuration parser.

Nevertheless, the init arguments are not enough for many use cases. For this reason the class is designed so that it
can be extended to customize different parts of the command line tool. The argument parser class used by
:class:`~pytorch_lightning.cli.LightningCLI` is
:class:`~pytorch_lightning.cli.LightningArgumentParser`, which is an extension of Python's argparse, thus
adding arguments can be done using the :func:`add_argument` method. In contrast to argparse, it has additional methods
to add arguments, for example :func:`add_class_arguments` adds all arguments from the init of a class, though it
requires parameters to have type hints. For more details about this please refer to the `respective documentation
<https://jsonargparse.readthedocs.io/en/stable/#classes-methods-and-functions>`_.

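As a minimal sketch of these parser methods used on their own, a group of options can be populated from the init of a
class with type hints. Here :code:`MyOptimizerConfig` is a hypothetical class used only for illustration; in a CLI the
same calls would typically go inside :meth:`~pytorch_lightning.cli.LightningCLI.add_arguments_to_parser`, described
next:

.. code-block:: python

    from pytorch_lightning.cli import LightningArgumentParser


    class MyOptimizerConfig:
        """Hypothetical class used only to illustrate ``add_class_arguments``."""

        def __init__(self, lr: float = 0.01, momentum: float = 0.9):
            self.lr = lr
            self.momentum = momentum


    parser = LightningArgumentParser()
    # a plain argparse-style argument
    parser.add_argument("--experiment_name", default="demo")
    # adds --optim.lr and --optim.momentum based on the type hints of MyOptimizerConfig.__init__
    parser.add_class_arguments(MyOptimizerConfig, "optim")
    cfg = parser.parse_args(["--optim.lr=0.1"])
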
The :class:`~pytorch_lightning.cli.LightningCLI` class has the
:meth:`~pytorch_lightning.cli.LightningCLI.add_arguments_to_parser` method which can be implemented to include
more arguments. After parsing, the configuration is stored in the :code:`config` attribute of the class instance. The
:class:`~pytorch_lightning.cli.LightningCLI` class also has two methods that can be used to run code before
and after the trainer runs: :code:`before_<subcommand>` and :code:`after_<subcommand>`.
A realistic example for these would be to send an email before and after the execution.
The code for the :code:`fit` subcommand would be something like:

.. testcode::

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_argument("--notification_email", default="will@email.com")

        def before_fit(self):
            send_email(address=self.config["notification_email"], message="trainer.fit starting")

        def after_fit(self):
            send_email(address=self.config["notification_email"], message="trainer.fit finished")


    cli = MyLightningCLI(MyModel)

Note that the config object :code:`self.config` is a dictionary whose keys are global options or groups of options. It
has the same structure as the yaml format described previously. This means for instance that the parameters used for
instantiating the trainer class can be found in :code:`self.config['fit']['trainer']`.

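For illustration only, a hook of a subclassed CLI could read values from this structure, for instance the group with
the trainer settings (a minimal sketch):

.. code-block:: python

    class MyLightningCLI(LightningCLI):
        def before_fit(self):
            # keys mirror the yaml structure: subcommand, then group, then option
            trainer_settings = self.config["fit"]["trainer"]
            print(trainer_settings)
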
.. tip::

    Have a look at the :class:`~pytorch_lightning.cli.LightningCLI` class API reference to learn about other
    methods that can be extended to customize a CLI.

----

**************************
Configure forced callbacks
**************************
As explained previously, any Lightning callback can be added by passing it through the command line or
by including it in the config via :code:`class_path` and :code:`init_args` entries.

However, certain callbacks MUST be coupled with a model so they are always present and configurable.
This can be implemented as follows:

.. testcode::

    from pytorch_lightning.callbacks import EarlyStopping


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.add_lightning_class_args(EarlyStopping, "my_early_stopping")
            parser.set_defaults({"my_early_stopping.monitor": "val_loss", "my_early_stopping.patience": 5})


    cli = MyLightningCLI(MyModel)

The configuration of this :code:`EarlyStopping` callback can then be changed in the config as follows:

.. code-block:: yaml

    model:
      ...
    trainer:
      ...
    my_early_stopping:
      patience: 5

.. note::

    The example above overrides a default in :code:`add_arguments_to_parser`. This is included to show that defaults
    can be changed if needed. However, note that overriding defaults in the source code is not intended as a way to
    store the best hyperparameters for a task after experimentation. To ease reproducibility the source code should be
    stable. It is better practice to store the best hyperparameters for a task in a configuration file independent
    from the source code.

----

*******************
Class type defaults
*******************

The support for classes as type hints makes it possible to try many possibilities with the same CLI. This is a useful
feature, but it can make it tempting to use an instance of a class as a default. For example:

.. testcode::

    class MyMainModel(LightningModule):
        def __init__(
            self,
            backbone: torch.nn.Module = MyModel(encoder_layers=24),  # BAD PRACTICE!
        ):
            super().__init__()
            self.backbone = backbone

Normally classes are mutable, as is the case here. The instance of :code:`MyModel` would be created the moment that
the module that defines :code:`MyMainModel` is first imported. This means that the default of :code:`backbone` will be
initialized before the CLI class runs :code:`seed_everything`, making it non-reproducible. Furthermore, if
:code:`MyMainModel` is used more than once in the same Python process and the :code:`backbone` parameter is not
overridden, the same instance would be used in multiple places, which very likely is not what the developer intended.
Having an instance as default also makes it impossible to generate the complete config file, since for arbitrary
classes it is not known which arguments were used to instantiate it.

A good solution to these problems is to not have a default, or to set the default to a special value (e.g. a string)
which is checked in the init and instantiated accordingly (a sketch of this pattern is shown at the end of this
section). If a class parameter has no default and the CLI is subclassed, then a default can be set as follows:

.. testcode::

    default_backbone = {
        "class_path": "import.path.of.MyModel",
        "init_args": {
            "encoder_layers": 24,
        },
    }


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.set_defaults({"model.backbone": default_backbone})

A more compact version that avoids writing a dictionary would be:

.. testcode::

    from jsonargparse import lazy_instance


    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.set_defaults({"model.backbone": lazy_instance(MyModel, encoder_layers=24)})

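For completeness, the other option mentioned earlier, a special default value checked in the init, could look like the
following minimal sketch. The :code:`"default"` sentinel string and the :code:`Union` type hint are just one possible
choice:

.. code-block:: python

    from typing import Union


    class MyMainModel(LightningModule):
        def __init__(self, backbone: Union[torch.nn.Module, str] = "default"):
            super().__init__()
            if backbone == "default":
                # the default instance is only created here, at init time, not at import time
                backbone = MyModel(encoder_layers=24)
            self.backbone = backbone
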
----

************************
Connect two config files
************************
Another case in which it might be desired to extend :class:`~pytorch_lightning.cli.LightningCLI` is when the
model and data module depend on a common parameter. For example, in some cases both classes need to know the
:code:`batch_size`. Giving the same value twice in a config file is a burden and error prone. To avoid this, the
parser can be configured so that a value is only given once and then propagated accordingly. With a tool implemented
as shown below, the :code:`batch_size` only has to be provided in the :code:`data` section of the config.

.. testcode::

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.batch_size", "model.batch_size")


    cli = MyLightningCLI(MyModel, MyDataModule)

The linking of arguments is reflected in the help of the tool, which for this example would look like:

.. code-block:: bash

    $ python trainer.py fit --help
    ...
      --data.batch_size BATCH_SIZE
                            Number of samples in a batch (type: int, default: 8)

    Linked arguments:
      model.batch_size <-- data.batch_size
                            Number of samples in a batch (type: int)

Sometimes a parameter value is only available after class instantiation. An example could be that your model requires
the number of classes to instantiate its fully connected layer (for a classification task) but the value is not
available until the data module has been instantiated. The code below illustrates how to address this.

.. testcode::

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")


    cli = MyLightningCLI(MyClassModel, MyDataModule)

Instantiation links are used to automatically determine the order of instantiation, in this case data first.

.. tip::

    The linking of arguments can be used for more complex cases, for example to derive a value via a function that
    takes multiple settings as input. For more details have a look at the API of `link_arguments
    <https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.ArgumentParser.link_arguments>`_.

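As a minimal sketch of such a case, a link could combine two options into a single model parameter via a compute
function; :code:`data.num_nodes` is a hypothetical option used only for illustration:

.. code-block:: python

    class MyLightningCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            # derive the model batch size from two data options instead of configuring it separately
            parser.link_arguments(
                ("data.batch_size", "data.num_nodes"),
                "model.batch_size",
                compute_fn=lambda batch_size, num_nodes: batch_size * num_nodes,
            )
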
The linking of arguments is intended for things that are meant to be non-configurable. This improves the CLI user
experience since it avoids the need to provide more parameters. A related concept is
variable interpolation, which in contrast keeps things configurable.