#############
Fabric (Beta)
#############

Fabric allows you to scale any PyTorch model with just a few lines of code!
With Fabric, you can easily scale your model to run on distributed devices using the strategy of your choice, while keeping full control over the training loop and optimization logic.

With only a few changes to your code, Fabric gives you:

- Automatic placement of models and data onto the device
- Automatic support for mixed precision (speedup and smaller memory footprint)
- Seamless switching between hardware (CPU, GPU, TPU)
- State-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed)
- An easy-to-use launch command for spawning processes (DDP, TorchElastic, etc.)
- Multi-node support (TorchElastic, SLURM, and more)
- Full control over your training loop

.. code-block:: diff

      import torch
      import torch.nn as nn
      from torch.utils.data import DataLoader, Dataset

    + from lightning.fabric import Fabric

      class PyTorchModel(nn.Module):
          ...

      class PyTorchDataset(Dataset):
          ...

    + fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
    + fabric.launch()

    - device = "cuda" if torch.cuda.is_available() else "cpu"
      model = PyTorchModel(...)
      optimizer = torch.optim.SGD(model.parameters())
    + model, optimizer = fabric.setup(model, optimizer)
      dataloader = DataLoader(PyTorchDataset(...), ...)
    + dataloader = fabric.setup_dataloaders(dataloader)
      model.train()

      for epoch in range(num_epochs):
          for batch in dataloader:
              input, target = batch
    -         input, target = input.to(device), target.to(device)
              optimizer.zero_grad()
              output = model(input)
              loss = loss_fn(output, target)
    -         loss.backward()
    +         fabric.backward(loss)
              optimizer.step()

.. note:: Fabric is currently in Beta. Its API is subject to change based on feedback.
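
The converted script above can be assembled into a complete program.
The following is a minimal sketch, not an official example. It assumes the ``lightning`` package is installed and uses a single CPU device so that it runs without a GPU.
The toy regression data, the ``nn.Linear`` model, the learning rate, and the ``train`` helper are illustrative choices that do not appear in the snippet above; only ``Fabric``, ``launch``, ``setup``, ``setup_dataloaders``, and ``backward`` are taken from it.
Switching to other hardware should only require changing the arguments passed to ``Fabric``.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    from lightning.fabric import Fabric


    def train(num_epochs: int = 2) -> float:
        """Train a tiny linear model with Fabric and return the last batch loss."""
        torch.manual_seed(0)

        # Assumption: one CPU device, so the sketch runs on any machine.
        fabric = Fabric(accelerator="cpu", devices=1)
        fabric.launch()

        # Hypothetical toy data: 64 samples, 4 features, scalar regression target.
        inputs = torch.randn(64, 4)
        targets = inputs.sum(dim=1, keepdim=True)
        dataloader = DataLoader(TensorDataset(inputs, targets), batch_size=16)

        model = nn.Linear(4, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        loss_fn = nn.MSELoss()

        # Fabric wraps the model, optimizer, and dataloader and takes over
        # device placement, so no manual .to(device) calls are needed.
        model, optimizer = fabric.setup(model, optimizer)
        dataloader = fabric.setup_dataloaders(dataloader)

        model.train()
        loss = torch.tensor(float("nan"))
        for epoch in range(num_epochs):
            for batch_inputs, batch_targets in dataloader:
                optimizer.zero_grad()
                loss = loss_fn(model(batch_inputs), batch_targets)
                fabric.backward(loss)  # replaces loss.backward()
                optimizer.step()
        return loss.item()


    if __name__ == "__main__":
        print(f"final batch loss: {train():.4f}")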


----------


************
Fundamentals
************

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: Getting Started
    :description: Learn how to add Fabric to your PyTorch code
    :button_link: fundamentals/convert.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Accelerators
    :description: Take advantage of your hardware with a switch of a flag
    :button_link: fundamentals/accelerators.html
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. displayitem::
    :header: Code Structure
    :description: Best practices for setting up your training script with Fabric
    :button_link: fundamentals/code_structure.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Distributed Operation
    :description: Launch a Python script on multiple devices and machines
    :button_link: fundamentals/launch.html
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. displayitem::
    :header: Fabric in Notebooks
    :description: Launch on multiple devices from within a Jupyter notebook
    :button_link: fundamentals/notebooks.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. raw:: html

        </div>
    </div>


----------


**********************
Build Your Own Trainer
**********************

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: The LightningModule
    :description: Organize your code in a LightningModule and use it with Fabric
    :button_link: guide/lightning_module.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Callbacks
    :description: Make use of the Callback system in Fabric
    :button_link: guide/callbacks.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Logging
    :description: Learn how Fabric helps you remove boilerplate code for tracking metrics with a logger
    :button_link: guide/logging.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Trainer Template
    :description: Take our Fabric Trainer template and customize it for your needs
    :button_link: guide/trainer_template.html
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. raw:: html

        </div>
    </div>


----------


***************
Advanced Topics
***************

Coming soon.


----------


.. _Fabric Examples:

********
Examples
********

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: Image Classification
    :description: Train an image classifier on the MNIST dataset
    :button_link: https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: GAN
    :description: Train a GAN that generates realistic human faces
    :button_link: https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/dcgan/README.md
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. displayitem::
    :header: Reinforcement Learning
    :description: Coming soon
    :col_css: col-md-4
    :height: 150

.. displayitem::
    :header: Active Learning
    :description: Coming soon
    :col_css: col-md-4
    :height: 150

.. displayitem::
    :header: Meta Learning
    :description: Coming soon
    :col_css: col-md-4
    :height: 150

.. raw:: html

        </div>
    </div>


----------


***
API
***

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: Fabric Arguments
    :description: All configuration options for the Fabric object
    :button_link: api/fabric_args.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Fabric Methods
    :description: Explore all methods that Fabric offers
    :button_link: api/fabric_methods.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Utilities
    :description: Explore utility functions that make your life easier
    :button_link: api/utilities.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Full API Reference
    :description: Reference of all public classes, methods, and functions. Useful for developers.
    :button_link: api/api_reference.html
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. raw:: html

        </div>
    </div>