#############
Fabric (Beta)
#############

Fabric allows you to scale any PyTorch model with just a few lines of code!
With Fabric, you can easily scale your model to run on distributed devices using the strategy of your choice while keeping complete control over the training loop and optimization logic.

With only a few changes to your code, Fabric gives you:

- Automatic placement of models and data onto the device
- Automatic support for mixed precision (speedup and smaller memory footprint)
- Seamless switching between hardware (CPU, GPU, TPU)
- State-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed)
- Easy-to-use launch command for spawning processes (DDP, torchelastic, etc.)
- Multi-node support (TorchElastic, SLURM, and more)
- Complete control over your training loop

.. code-block:: diff

      import torch
      import torch.nn as nn
      from torch.utils.data import DataLoader, Dataset

    + from lightning.fabric import Fabric

      class PyTorchModel(nn.Module):
          ...

      class PyTorchDataset(Dataset):
          ...

    + fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
    + fabric.launch()

    - device = "cuda" if torch.cuda.is_available() else "cpu"
      model = PyTorchModel(...)
      optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
      lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
    + model, optimizer = fabric.setup(model, optimizer)
      dataloader = DataLoader(PyTorchDataset(...), ...)
    + dataloader = fabric.setup_dataloaders(dataloader)
      loss_fn = nn.CrossEntropyLoss()
      num_epochs = 10
      model.train()

      for epoch in range(num_epochs):
          for batch in dataloader:
              input, target = batch
    -         input, target = input.to(device), target.to(device)
              optimizer.zero_grad()
              output = model(input)
              loss = loss_fn(output, target)
    -         loss.backward()
    +         fabric.backward(loss)
              optimizer.step()
              lr_scheduler.step()

.. note:: Fabric is currently in Beta. Its API is subject to change based on feedback.
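
For a quick local check, the same pattern can be run end to end on a single CPU. The sketch below is a minimal, self-contained variant of the diff above. It assumes ``torch`` and ``lightning`` are installed. The synthetic linear-regression data, the ``train`` helper, and all hyperparameters are made up for illustration and are not part of Fabric.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.fabric import Fabric


    def train(num_epochs: int = 20) -> list:
        """Hypothetical helper: fit y = 3x - 1 with Fabric on one CPU; return mean loss per epoch."""
        torch.manual_seed(0)

        # Made-up synthetic data: 256 points on a line plus a little noise
        x = torch.randn(256, 1)
        y = 3 * x - 1 + 0.01 * torch.randn(256, 1)
        dataloader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

        # A single CPU device, so nothing is actually distributed here
        fabric = Fabric(accelerator="cpu", devices=1)
        fabric.launch()

        model = nn.Linear(1, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = fabric.setup(model, optimizer)
        dataloader = fabric.setup_dataloaders(dataloader)  # batches land on the Fabric device
        loss_fn = nn.MSELoss()

        epoch_losses = []
        model.train()
        for _ in range(num_epochs):
            total, num_batches = 0.0, 0
            for inputs, target in dataloader:
                optimizer.zero_grad()
                loss = loss_fn(model(inputs), target)
                fabric.backward(loss)  # used instead of loss.backward()
                optimizer.step()
                total += loss.item()
                num_batches += 1
            epoch_losses.append(total / num_batches)
        return epoch_losses

Calling ``train()`` returns the average training loss of each epoch. On this toy problem the loss should drop well below its first-epoch value. Moving to other hardware is meant to require only different ``Fabric(...)`` arguments, such as the ``accelerator="cuda", devices=8, strategy="ddp"`` used in the diff above.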

----

************
Fundamentals
************

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
   :header: Getting Started
   :description: Learn how to add Fabric to your PyTorch code
   :button_link: fundamentals/convert.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Accelerators
   :description: Take advantage of your hardware by switching a single flag
   :button_link: fundamentals/accelerators.html
   :col_css: col-md-4
   :height: 150
   :tag: intermediate

.. displayitem::
   :header: Code Structure
   :description: Best practices for setting up your training script with Fabric
   :button_link: fundamentals/code_structure.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Launch Distributed Training
   :description: Launch a Python script on multiple devices and machines
   :button_link: fundamentals/launch.html
   :col_css: col-md-4
   :height: 150
   :tag: intermediate

.. displayitem::
   :header: Fabric in Notebooks
   :description: Launch on multiple devices from within a Jupyter notebook
   :button_link: fundamentals/notebooks.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Mixed Precision Training
   :description: Save memory and speed up training using mixed precision
   :button_link: fundamentals/precision.html
   :col_css: col-md-4
   :height: 150
   :tag: intermediate

.. raw:: html

        </div>
    </div>

----

**********************
Build Your Own Trainer
**********************

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
   :header: The LightningModule
   :description: Organize your code in a LightningModule and use it with Fabric
   :button_link: guide/lightning_module.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Callbacks
   :description: Make use of the Callback system in Fabric
   :button_link: guide/callbacks.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Logging
   :description: Learn how Fabric helps you remove boilerplate code for tracking metrics with a logger
   :button_link: guide/logging.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Checkpoints
   :description: Efficient saving and loading of model weights, training state, hyperparameters and more.
   :button_link: guide/checkpoint.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Trainer Template
   :description: Take our Fabric Trainer template and customize it for your needs
   :button_link: guide/trainer_template.html
   :col_css: col-md-4
   :height: 150
   :tag: intermediate

.. raw:: html

        </div>
    </div>

----

***************
Advanced Topics
***************

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
   :header: Efficient Gradient Accumulation
   :description: Learn how to perform efficient gradient accumulation in distributed settings
   :button_link: advanced/gradient_accumulation.html
   :col_css: col-md-4
   :height: 160
   :tag: advanced

.. displayitem::
   :header: Distributed Communication
   :description: Learn about communication primitives for distributed operations: gather, reduce, broadcast, and more.
   :button_link: advanced/distributed_communication.html
   :col_css: col-md-4
   :height: 160
   :tag: advanced

.. raw:: html

        </div>
    </div>

----

.. _Fabric Examples:

********
Examples
********

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
   :header: Image Classification
   :description: Train an image classifier on the MNIST dataset
   :button_link: https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: GAN
   :description: Train a GAN that generates realistic human faces
   :button_link: https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/dcgan
   :col_css: col-md-4
   :height: 150
   :tag: intermediate

.. displayitem::
   :header: Meta-Learning
   :description: Distributed training with the MAML algorithm on the Omniglot and MiniImagenet datasets
   :button_link: https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/meta_learning
   :col_css: col-md-4
   :height: 150
   :tag: intermediate

.. displayitem::
   :header: Large Language Models
   :description: Pre-train a GPT-2 language model on OpenWebText data
   :button_link: https://github.com/Lightning-AI/nanoGPT/blob/master/train_fabric.py
   :col_css: col-md-4
   :height: 150
   :tag: advanced

.. displayitem::
   :header: Reinforcement Learning
   :description: Implementation of the Proximal Policy Optimization (PPO) algorithm with multi-GPU support
   :button_link: https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/reinforcement_learning
   :col_css: col-md-4
   :height: 150

.. displayitem::
   :header: Active Learning
   :description: Coming soon
   :col_css: col-md-4
   :height: 150

.. raw:: html

        </div>
    </div>

----

***
API
***

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
   :header: Fabric Arguments
   :description: All configuration options for the Fabric object
   :button_link: api/fabric_args.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Fabric Methods
   :description: Explore all methods that Fabric offers
   :button_link: api/fabric_methods.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Utilities
   :description: Explore utility functions that make your life easier
   :button_link: api/utilities.html
   :col_css: col-md-4
   :height: 150
   :tag: basic

.. displayitem::
   :header: Full API Reference
   :description: Reference of all public classes, methods and functions. Useful for developers.
   :button_link: api/api_reference.html
   :col_css: col-md-4
   :height: 150
   :tag: intermediate

.. raw:: html

        </div>
    </div>