From 0e84f01b09def0d4befd4998ee2a0570d77ed26b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 7 Mar 2023 13:19:43 +0100 Subject: [PATCH] Document how to use multiple models and optimizers in Fabric (#16952) --- .../source-fabric/advanced/multiple_setup.rst | 116 ++++++++++++++++++ docs/source-fabric/index.rst | 9 ++ 2 files changed, 125 insertions(+) create mode 100644 docs/source-fabric/advanced/multiple_setup.rst diff --git a/docs/source-fabric/advanced/multiple_setup.rst b/docs/source-fabric/advanced/multiple_setup.rst new file mode 100644 index 0000000000..1ecdcc9dee --- /dev/null +++ b/docs/source-fabric/advanced/multiple_setup.rst @@ -0,0 +1,116 @@ +:orphan: + +############################## +Multiple Models and Optimizers +############################## + +Fabric makes it very easy to work with multiple models and/or optimizers at once in your training workflow. +Examples of where this comes in handy are Generative Adversarial Networks (GANs), Auto-encoders, meta-learning and more. + + +---- + +************************ +One model, one optimizer +************************ + +Fabric has a simple guideline you should follow: +If you have an optimizer, you should set it up together with the model to make your code truly strategy-agnostic. + +.. code-block:: python + + import torch + from lightning.fabric import Fabric + + fabric = Fabric() + + # Instantiate model and optimizer + model = LitModel() + optimizer = torch.optim.Adam(model.parameters()) + + # Set up the model and optimizer together + model, optimizer = fabric.setup(model, optimizer) + + +Depending on the selected strategy, the :meth:`~lightning.fabric.fabric.Fabric.setup` method will wrap and link the model with the optimizer. + + +---- + + +****************************** +One model, multiple optimizers +****************************** + +You can also have multiple optimizers over a single model. +This is useful if you need specific optimizers or learning rates for parts of the model. + +.. code-block:: python + + # Instantiate model and optimizers + model = LitModel() + optimizer1 = torch.optim.SGD(model.layer1.parameters(), lr=0.003) + optimizer2 = torch.optim.SGD(model.layer2.parameters(), lr=0.01) + + # Set up the model and optimizers together + model, optimizer1, optimizer2 = fabric.setup(model, optimizer1, optimizer2) + + + +---- + + +****************************** +Multiple models, one optimizer +****************************** + +Using a single optimizer to update multiple models is possible too. +The best way to do this is to group all your individual models under one top level ``nn.Module``: + +.. code-block:: python + + class AutoEncoder(torch.nn.Module): + def __init__(self): + super().__init__() + + # Group all models under a common nn.Module + self.encoder = Encoder() + self.decoder = Decoder() + +Now all of these models can be treated as a single one: + +.. code-block:: python + + # Instantiate the big model + autoencoder = AutoEncoder() + optimizer = ... + + # Set up the model(s) and optimizer together + autoencoder, optimizer = fabric.setup(autoencoder, optimizer) + + +---- + + +************************************ +Multiple models, multiple optimizers +************************************ + +You can pair up as many models and optimizers as you want. For example, two models with one optimizer each: + +.. code-block:: python + + # Two models + generator = Generator() + discriminator = Discriminator() + + # Two optimizers + optimizer_gen = torch.optim.SGD(generator.parameters(), lr=0.01) + optimizer_dis = torch.optim.SGD(discriminator.parameters(), lr=0.001) + + # Set up generator + generator, optimizer_gen = fabric.setup(generator, optimizer_gen) + # Set up discriminator + discriminator, optimizer_dis = fabric.setup(discriminator, optimizer_dis) + +For a full example of this use case, see our `GAN example `_. diff --git a/docs/source-fabric/index.rst b/docs/source-fabric/index.rst index 0846cd8c0d..eca5572eca 100644 --- a/docs/source-fabric/index.rst +++ b/docs/source-fabric/index.rst @@ -231,6 +231,14 @@ Advanced Topics :height: 160 :tag: advanced +.. displayitem:: + :header: Multiple Models and Optimizers + :description: See how flexible Fabric is to work with multiple models and optimizers! + :button_link: advanced/multiple_setup.html + :col_css: col-md-4 + :height: 160 + :tag: advanced + .. raw:: html @@ -275,6 +283,7 @@ Advanced Topics Efficient Gradient Accumulation Distributed Communication + Multiple Models and Optimizers .. toctree:: :maxdepth: 1