data:image/s3,"s3://crabby-images/93a52/93a521d0497d6a7c71addb70768ef302a513eeeb" alt=""
**Fabric is the fast and lightweight way to scale PyTorch models without boilerplate**
______________________________________________________________________
[data:image/s3,"s3://crabby-images/36f58/36f581975badbd083cf4a3249a73ce2fc578e377" alt="PyPI - Python Version"](https://pypi.org/project/lightning_fabric/)
[data:image/s3,"s3://crabby-images/4d83b/4d83bd75d767bc818d7e7fbe471386cf3ac42065" alt="PyPI Status"](https://badge.fury.io/py/lightning_fabric)
[data:image/s3,"s3://crabby-images/fe7a0/fe7a08a4ef805a220965a54d43b2fce57aa0b2a7" alt="PyPI Status"](https://pepy.tech/project/lightning_fabric)
[data:image/s3,"s3://crabby-images/6e856/6e8560a3441418f2c81e1da70706f6b1e22079fe" alt="Conda"](https://anaconda.org/conda-forge/lightning_fabric)
## Maximum flexibility, minimum code changes
With just a few code changes, run any PyTorch model on any device or distributed hardware, with no boilerplate!
- Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training
- Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box
- All the device logic boilerplate is handled for you
- Designed with multi-billion parameter models in mind
- Build your own custom Trainer using Fabric primitives for training, checkpointing, logging, and more
```diff
+ import lightning as L
  import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader, Dataset

  class PyTorchModel(nn.Module):
      ...

  class PyTorchDataset(Dataset):
      ...

+ fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ fabric.launch()

- device = "cuda" if torch.cuda.is_available() else "cpu"
  model = PyTorchModel(...)
- model.to(device)
  optimizer = torch.optim.SGD(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
  dataloader = DataLoader(PyTorchDataset(...), ...)
+ dataloader = fabric.setup_dataloaders(dataloader)
  model.train()

  # num_epochs, loss_fn and lr_scheduler are assumed to come from your existing code
  for epoch in range(num_epochs):
      for batch in dataloader:
          input, target = batch
-         input, target = input.to(device), target.to(device)
          optimizer.zero_grad()
          output = model(input)
          loss = loss_fn(output, target)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          lr_scheduler.step()
```
______________________________________________________________________
# Getting started
## Install Lightning