:orphan:

##############################
Multiple Models and Optimizers
##############################

Fabric makes it easy to work with multiple models and/or optimizers at once in your training workflow.
Examples where this comes in handy include Generative Adversarial Networks (GANs), autoencoders, meta-learning, and more.


----


************************
One model, one optimizer
************************

Fabric has a simple guideline you should follow:
if you have an optimizer, set it up together with the model to make your code truly strategy-agnostic.

.. code-block:: python

    import torch
    from lightning.fabric import Fabric

    fabric = Fabric()

    # Instantiate model and optimizer
    model = LitModel()  # your own torch.nn.Module
    optimizer = torch.optim.Adam(model.parameters())

    # Set up the model and optimizer together
    model, optimizer = fabric.setup(model, optimizer)

Depending on the selected strategy, the :meth:`~lightning.fabric.fabric.Fabric.setup` method will wrap and link the model with the optimizer.


----


******************************
One model, multiple optimizers
******************************

You can also have multiple optimizers over a single model.
This is useful when different parts of the model need different optimizers or learning rates.

.. code-block:: python

    # Instantiate model and optimizers
    # (assumes LitModel has two submodules, `layer1` and `layer2`)
    model = LitModel()
    optimizer1 = torch.optim.SGD(model.layer1.parameters(), lr=0.003)
    optimizer2 = torch.optim.SGD(model.layer2.parameters(), lr=0.01)

    # Set up the model and optimizers together
    model, optimizer1, optimizer2 = fabric.setup(model, optimizer1, optimizer2)


----


******************************
Multiple models, one optimizer
******************************

Using a single optimizer to update multiple models is possible too.
The best way to do this is to group all your individual models under one top-level ``nn.Module``:

.. code-block:: python

    class AutoEncoder(torch.nn.Module):
        def __init__(self):
            super().__init__()

            # Group all models under a common nn.Module
            # (Encoder and Decoder are your own nn.Module subclasses)
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, x):
            # Run the grouped models in sequence
            return self.decoder(self.encoder(x))

Now the encoder and decoder can be treated as a single model:

.. code-block:: python

    # Instantiate the big model
    autoencoder = AutoEncoder()
    optimizer = ...  # e.g., an optimizer over autoencoder.parameters()

    # Set up the model(s) and optimizer together
    autoencoder, optimizer = fabric.setup(autoencoder, optimizer)


----


************************************
Multiple models, multiple optimizers
************************************

You can pair up as many models and optimizers as you want.
For example, two models with one optimizer each:

.. code-block:: python

    # Two models
    generator = Generator()
    discriminator = Discriminator()

    # Two optimizers
    optimizer_gen = torch.optim.SGD(generator.parameters(), lr=0.01)
    optimizer_dis = torch.optim.SGD(discriminator.parameters(), lr=0.001)

    # Set up generator
    generator, optimizer_gen = fabric.setup(generator, optimizer_gen)
    # Set up discriminator
    discriminator, optimizer_dis = fabric.setup(discriminator, optimizer_dis)

For a full example of this use case, see our `GAN example `_.
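The sketch below shows one way two model/optimizer pairs could drive a single GAN training step. It is a minimal illustration under stated assumptions and is not the referenced GAN example. Small ``torch.nn.Sequential`` networks stand in for the ``Generator`` and ``Discriminator`` classes, and the loss is the common non-saturating binary cross-entropy GAN loss. The ``training_step`` helper, layer sizes, learning rates, and toy data are hypothetical. Each loss is backpropagated with :meth:`~lightning.fabric.fabric.Fabric.backward`, and each optimizer steps only its own model.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=1)
    fabric.launch()

    # Hypothetical toy models: 8-d noise -> 2-d samples, and 2-d samples -> one logit
    generator = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
    discriminator = torch.nn.Sequential(torch.nn.Linear(2, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))

    optimizer_gen = torch.optim.SGD(generator.parameters(), lr=0.01)
    optimizer_dis = torch.optim.SGD(discriminator.parameters(), lr=0.001)

    # Each model is set up together with its own optimizer
    generator, optimizer_gen = fabric.setup(generator, optimizer_gen)
    discriminator, optimizer_dis = fabric.setup(discriminator, optimizer_dis)


    def training_step(real):
        """Hypothetical helper: one discriminator update, then one generator update."""
        batch_size = real.size(0)
        noise = torch.randn(batch_size, 8, device=fabric.device)
        ones = torch.ones(batch_size, 1, device=fabric.device)
        zeros = torch.zeros(batch_size, 1, device=fabric.device)
        fake = generator(noise)

        # Discriminator: push real samples toward 1 and detached fake samples toward 0
        loss_dis = F.binary_cross_entropy_with_logits(
            discriminator(real), ones
        ) + F.binary_cross_entropy_with_logits(discriminator(fake.detach()), zeros)
        optimizer_dis.zero_grad()
        fabric.backward(loss_dis)
        optimizer_dis.step()

        # Generator: try to make the discriminator classify fake samples as real
        loss_gen = F.binary_cross_entropy_with_logits(discriminator(fake), ones)
        optimizer_gen.zero_grad()
        fabric.backward(loss_gen)
        optimizer_gen.step()

        return loss_dis.item(), loss_gen.item()


    # Toy "real" data: 2-d points scattered around (2, 2)
    for _ in range(100):
        real = torch.randn(32, 2, device=fabric.device) * 0.5 + 2.0
        loss_dis, loss_gen = training_step(real)

Calling ``.detach()`` on the fake samples during the discriminator update keeps that loss from producing gradients for the generator. The generator loss does populate the discriminator's gradients, but ``optimizer_dis.zero_grad()`` clears them before the next discriminator update.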