:orphan:
##############################
Convert PyTorch code to Fabric
##############################

Here are five easy steps to let :class:`~lightning_fabric.fabric.Fabric` scale your PyTorch models.

**Step 1:** Create the :class:`~lightning_fabric.fabric.Fabric` object at the beginning of your training code.

.. code-block:: python

    from lightning.fabric import Fabric

    fabric = Fabric()
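
You don't need to pass any arguments, but :class:`~lightning_fabric.fabric.Fabric` also accepts flags such as ``accelerator``, ``devices``, and ``strategy`` to control where and how you train. A minimal sketch, with example values to adjust for your setup:

.. code-block:: python

    # Request specific hardware and scaling explicitly (example values)
    fabric = Fabric(accelerator="cuda", devices=2, strategy="ddp")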

**Step 2:** Call :meth:`~lightning_fabric.fabric.Fabric.setup` on each model and optimizer pair and :meth:`~lightning_fabric.fabric.Fabric.setup_dataloaders` on all your data loaders.

.. code-block:: python

    model, optimizer = fabric.setup(model, optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)
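
:meth:`~lightning_fabric.fabric.Fabric.setup_dataloaders` also accepts several data loaders at once and returns them in the same order. A minimal sketch, assuming separate training and validation loaders:

.. code-block:: python

    # Multiple loaders can be set up in a single call; they come back in the same order
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)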

**Step 3:** Remove all ``.to`` and ``.cuda`` calls since :class:`~lightning_fabric.fabric.Fabric` will take care of it.

.. code-block:: diff

    - model.to(device)
    - batch.to(device)
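
Batches from a loader that went through :meth:`~lightning_fabric.fabric.Fabric.setup_dataloaders` land on the right device automatically. For tensors you create yourself inside the loop, the current device is available as ``fabric.device``. A minimal sketch:

.. code-block:: python

    # Instead of hard-coding "cuda", place manually created tensors via fabric.device
    mask = torch.ones(8, device=fabric.device)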

**Step 4:** Replace ``loss.backward()`` with ``fabric.backward(loss)``.

.. code-block:: diff

    - loss.backward()
    + fabric.backward(loss)
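
:meth:`~lightning_fabric.fabric.Fabric.backward` routes the backward pass through the active strategy and precision settings, so a mixed-precision run needs no manual gradient-scaler code. A minimal sketch, with ``precision=16`` as one example setting:

.. code-block:: python

    # With 16-bit precision, fabric.backward applies gradient scaling internally,
    # so no torch.cuda.amp.GradScaler boilerplate is required
    fabric = Fabric(precision=16)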

**Step 5:** Run the script from the terminal with

.. code-block:: bash

    lightning run model path/to/train.py

or use the :meth:`~lightning_fabric.fabric.Fabric.launch` method in a notebook.
Learn more about :doc:`launching distributed training <launch>`.
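
The launcher accepts the same hardware selection as command-line flags, so you can try different scaling setups without editing the script. A sketch with example values:

.. code-block:: bash

    lightning run model --accelerator=cuda --devices=8 --strategy=ddp path/to/train.py
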
|
All steps combined, this is how your code will change:

.. code-block:: diff

      import torch
      import torch.nn as nn
      from torch.utils.data import DataLoader, Dataset

    + from lightning.fabric import Fabric


      class PyTorchModel(nn.Module):
          ...


      class PyTorchDataset(Dataset):
          ...


    + fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
    + fabric.launch()

    - device = "cuda" if torch.cuda.is_available() else "cpu"

      model = PyTorchModel(...)
      optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    + model, optimizer = fabric.setup(model, optimizer)

      dataloader = DataLoader(PyTorchDataset(...), ...)
    + dataloader = fabric.setup_dataloaders(dataloader)

      model.train()
      for epoch in range(num_epochs):
          for batch in dataloader:
              input, target = batch
    -         input, target = input.to(device), target.to(device)
              optimizer.zero_grad()
              output = model(input)
              loss = loss_fn(output, target)
    -         loss.backward()
    +         fabric.backward(loss)
              optimizer.step()
          lr_scheduler.step()

That's it! You can now train on any device, at any scale, simply by switching a flag.
Check out our before-and-after example for `image classification <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md>`_ and many more :ref:`examples <Fabric Examples>` that use Fabric.

**********
Next steps
**********

.. raw:: html

    <div class="display-card-container">
        <div class="row">

.. displayitem::
    :header: Examples
    :description: See examples across computer vision, NLP, RL, etc.
    :col_css: col-md-4
    :button_link: ../fabric.html#examples
    :height: 150
    :tag: basic

.. displayitem::
    :header: Accelerators
    :description: Take advantage of your hardware by switching a flag
    :button_link: accelerators.html
    :col_css: col-md-4
    :height: 150
    :tag: intermediate

.. displayitem::
    :header: Build your own Trainer
    :description: Learn how to build a trainer tailored for you
    :col_css: col-md-4
    :button_link: ../fabric.html#build-your-own-trainer
    :height: 150
    :tag: intermediate

.. raw:: html

        </div>
    </div>