:orphan:
#####################################################
Configure hyperparameters from the CLI (Intermediate)
#####################################################
**Audience:** Users who have multiple models and datasets per project.
**Pre-reqs:** You must have read :doc:`(Control it all from the CLI) <lightning_cli_intermediate>`.
----
***************************
Why mix models and datasets
***************************
Lightning projects usually begin with one model and one dataset. As the project grows in complexity and you introduce
more models and more datasets, it becomes desirable to mix any model with any dataset directly from the command line
without changing your code.
.. code:: bash

    # Mix and match anything
    $ python main.py fit --model=GAN --data=MNIST
    $ python main.py fit --model=Transformer --data=MNIST
``LightningCLI`` makes this very simple. Otherwise, this kind of configuration requires a significant amount of
boilerplate that often looks like this:
.. code:: python

    # choose model
    if args.model == "gan":
        model = GAN(args.feat_dim)
    elif args.model == "transformer":
        model = Transformer(args.feat_dim)
    ...

    # choose datamodule
    if args.data == "MNIST":
        datamodule = MNIST()
    elif args.data == "imagenet":
        datamodule = Imagenet()
    ...

    # mix them!
    trainer.fit(model, datamodule)
It is highly recommended that you avoid writing this kind of boilerplate and use ``LightningCLI`` instead.
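For comparison, the ``LightningCLI`` equivalent of all of the above dispatch code is a single line. As a minimal
sketch (the model and data module classes just need to be defined or imported in the same file):

.. code:: python

    # main.py -- minimal sketch: with no fixed classes, model and data are chosen from the CLI
    from pytorch_lightning.cli import LightningCLI

    cli = LightningCLI()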
----
*************************
Multiple LightningModules
*************************
To support multiple models, when instantiating ``LightningCLI`` omit the ``model_class`` parameter:
.. code:: python

    # main.py
    from pytorch_lightning.cli import LightningCLI
    from pytorch_lightning.demos.boring_classes import BoringDataModule, DemoModel


    class Model1(DemoModel):
        def configure_optimizers(self):
            print("⚡", "using Model1", "⚡")
            return super().configure_optimizers()


    class Model2(DemoModel):
        def configure_optimizers(self):
            print("⚡", "using Model2", "⚡")
            return super().configure_optimizers()


    cli = LightningCLI(datamodule_class=BoringDataModule)
Now you can choose either model from the CLI:

.. code:: bash

    # use Model1
    python main.py fit --model Model1

    # use Model2
    python main.py fit --model Model2
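Init arguments of the selected model can be set the same way. For example, assuming ``DemoModel`` accepts a
``learning_rate`` init argument, it could be overridden with:

.. code:: bash

    python main.py fit --model Model2 --model.learning_rate 0.1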
.. tip::

    Instead of omitting the ``model_class`` parameter, you can give a base class and ``subclass_mode_model=True``. This
    will make the CLI only accept models which are a subclass of the given base class.
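As a sketch of subclass mode (``BaseModel`` here is a hypothetical base class of your project; data modules work
analogously via ``subclass_mode_data=True``):

.. code:: python

    # main.py -- sketch: only subclasses of the hypothetical BaseModel are accepted for --model
    from pytorch_lightning.cli import LightningCLI

    from my_code.models import BaseModel  # hypothetical import

    cli = LightningCLI(BaseModel, subclass_mode_model=True)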
----
*****************************
Multiple LightningDataModules
*****************************
To support multiple data modules, when instantiating ``LightningCLI`` omit the ``datamodule_class`` parameter:
.. code:: python

    # main.py
    import torch

    from pytorch_lightning.cli import LightningCLI
    from pytorch_lightning.demos.boring_classes import BoringDataModule, DemoModel


    class FakeDataset1(BoringDataModule):
        def train_dataloader(self):
            print("⚡", "using FakeDataset1", "⚡")
            return torch.utils.data.DataLoader(self.random_train)


    class FakeDataset2(BoringDataModule):
        def train_dataloader(self):
            print("⚡", "using FakeDataset2", "⚡")
            return torch.utils.data.DataLoader(self.random_train)


    cli = LightningCLI(DemoModel)
Now you can choose either data module from the CLI:

.. code:: bash

    # use FakeDataset1
    python main.py fit --data FakeDataset1

    # use FakeDataset2
    python main.py fit --data FakeDataset2
.. tip::

    Instead of omitting the ``datamodule_class`` parameter, you can give a base class and ``subclass_mode_data=True``.
    This will make the CLI only accept data modules that are a subclass of the given base class.
----
*******************
Multiple optimizers
*******************
Standard optimizers from ``torch.optim`` work out of the box:
.. code:: bash

    python main.py fit --optimizer AdamW
If the optimizer you want needs other arguments, add them via the CLI (no need to change your code)!
.. code:: bash

    python main.py fit --optimizer SGD --optimizer.lr=0.01
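To inspect how the CLI resolved the optimizer and its arguments, you can print the fully parsed configuration with
the ``--print_config`` flag:

.. code:: bash

    python main.py fit --optimizer SGD --optimizer.lr=0.01 --print_config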
Furthermore, any custom subclass of :class:`torch.optim.Optimizer` can be used as an optimizer:
.. code:: python

    # main.py
    import torch

    from pytorch_lightning.cli import LightningCLI
    from pytorch_lightning.demos.boring_classes import BoringDataModule, DemoModel


    class LitAdam(torch.optim.Adam):
        def step(self, closure):
            print("⚡", "using LitAdam", "⚡")
            super().step(closure)


    class FancyAdam(torch.optim.Adam):
        def step(self, closure):
            print("⚡", "using FancyAdam", "⚡")
            super().step(closure)


    cli = LightningCLI(DemoModel, BoringDataModule)
Now you can choose between the optimizers at runtime:

.. code:: bash

    # use LitAdam
    python main.py fit --optimizer LitAdam

    # use FancyAdam
    python main.py fit --optimizer FancyAdam
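Since both classes subclass :class:`torch.optim.Adam`, their inherited init arguments such as ``lr`` can be set from
the CLI as well:

.. code:: bash

    python main.py fit --optimizer LitAdam --optimizer.lr=0.001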
----
*******************
Multiple schedulers
*******************
Standard learning rate schedulers from ``torch.optim.lr_scheduler`` work out of the box:
.. code:: bash

    python main.py fit --lr_scheduler CosineAnnealingLR
If the scheduler you want needs other arguments, add them via the CLI (no need to change your code)!
.. code:: bash

    python main.py fit --lr_scheduler=ReduceLROnPlateau --lr_scheduler.monitor=epoch
Furthermore, any custom subclass of ``torch.optim.lr_scheduler.LRScheduler`` can be used as a learning rate scheduler:

.. code:: python

    # main.py
    import torch

    from pytorch_lightning.cli import LightningCLI
    from pytorch_lightning.demos.boring_classes import BoringDataModule, DemoModel


    class LitLRScheduler(torch.optim.lr_scheduler.CosineAnnealingLR):
        def step(self):
            print("⚡", "using LitLRScheduler", "⚡")
            super().step()


    cli = LightningCLI(DemoModel, BoringDataModule)
Now you can choose this scheduler at runtime:

.. code:: bash

    # use LitLRScheduler
    python main.py fit --lr_scheduler LitLRScheduler
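Note that ``CosineAnnealingLR`` has a required ``T_max`` init argument, so a complete invocation would also provide
it:

.. code:: bash

    python main.py fit --lr_scheduler LitLRScheduler --lr_scheduler.T_max=100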
----
************************
Classes from any package
************************
In the previous sections, the custom classes to select from were defined in the same Python file where the
``LightningCLI`` class is run. To select classes from any package by using only the class name, import the respective
package:
.. code:: python

    from pytorch_lightning.cli import LightningCLI

    import my_code.models  # noqa: F401
    import my_code.data_modules  # noqa: F401
    import my_code.optimizers  # noqa: F401

    cli = LightningCLI()
Now use any of the classes:
.. code:: bash

    python main.py fit --model Model1 --data FakeDataset1 --optimizer LitAdam --lr_scheduler LitLRScheduler
The ``# noqa: F401`` comment avoids a linter warning that the import is unused.
It is also possible to select subclasses that have not been imported by giving the full import path:
.. code:: bash

    python main.py fit --model my_code.models.Model1
----
*************************
Help for specific classes
*************************
When multiple models or datasets are accepted, the main help of the CLI does not include their specific parameters.
To show the help for a specific class, there are additional help arguments that expect the class name or its import
path. For example:
.. code:: bash

    python main.py fit --model.help Model1
    python main.py fit --data.help FakeDataset2
    python main.py fit --optimizer.help Adagrad
    python main.py fit --lr_scheduler.help StepLR
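The same help arguments also accept a full import path, which is useful for classes that have not been imported:

.. code:: bash

    python main.py fit --model.help my_code.models.Model1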