##############################
Convert PyTorch code to Fabric
##############################

Here are five easy steps to let :class:`~lightning.fabric.fabric.Fabric` scale your PyTorch models.

**Step 1:** Create the :class:`~lightning.fabric.fabric.Fabric` object at the beginning of your training code.

.. code-block:: python

    from lightning.fabric import Fabric

    fabric = Fabric()
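
If you already know what hardware you are targeting, you can configure it right here instead of relying on the defaults. A minimal sketch, assuming a machine with two GPUs; ``accelerator``, ``devices``, ``strategy``, and ``precision`` are standard ``Fabric`` arguments:

.. code-block:: python

    from lightning.fabric import Fabric

    # Hedged example: request two GPUs with DDP and mixed precision up front.
    fabric = Fabric(accelerator="gpu", devices=2, strategy="ddp", precision="16-mixed")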

**Step 2:** Call :meth:`~lightning.fabric.fabric.Fabric.launch` if you intend to use multiple devices (e.g., multi-GPU).

.. code-block:: python

    fabric.launch()
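
With strategies that spawn processes, the usual Python multiprocessing caveats apply, so it is good practice to guard your entry point; everything after ``launch()`` then runs in every process. A minimal sketch of this pattern (the ``main`` function name is illustrative):

.. code-block:: python

    from lightning.fabric import Fabric


    def main():
        fabric = Fabric(accelerator="gpu", devices=2)
        # Code after launch() executes on all processes.
        fabric.launch()
        ...


    if __name__ == "__main__":
        main()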

**Step 3:** Call :meth:`~lightning.fabric.fabric.Fabric.setup` on each model and optimizer pair and :meth:`~lightning.fabric.fabric.Fabric.setup_dataloaders` on all your data loaders.

.. code-block:: python

    model, optimizer = fabric.setup(model, optimizer)
    dataloader = fabric.setup_dataloaders(dataloader)
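
``setup_dataloaders`` also accepts several loaders in one call and returns them in the same order. A short sketch, assuming separate (hypothetical) train and validation loaders already exist:

.. code-block:: python

    # Both loaders come back wrapped for automatic device placement.
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)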

**Step 4:** Remove all ``.to`` and ``.cuda`` calls since :class:`~lightning.fabric.fabric.Fabric` will take care of it.

.. code-block:: diff

    - model.to(device)
    - batch.to(device)
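
Batches produced by a loader that went through ``setup_dataloaders`` land on the right device automatically. For tensors you create by hand, use :meth:`~lightning.fabric.fabric.Fabric.to_device` instead of hard-coding a device. A small sketch:

.. code-block:: python

    import torch

    # A manually created tensor starts on the CPU ...
    mask = torch.ones(8, 8)
    # ... so let Fabric move it to whatever device it manages.
    mask = fabric.to_device(mask)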

**Step 5:** Replace ``loss.backward()`` with ``fabric.backward(loss)``.

.. code-block:: diff

    - loss.backward()
    + fabric.backward(loss)
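
``fabric.backward(loss)`` matters because Fabric may scale the loss (for example, under mixed precision) and handles strategy-specific details for you. A related, hedged tip: if you clip gradients, call :meth:`~lightning.fabric.fabric.Fabric.clip_gradients` so gradients are unscaled before clipping; the ``max_norm`` value below is illustrative:

.. code-block:: python

    fabric.backward(loss)
    # Unscales gradients first (when applicable), then clips by norm.
    fabric.clip_gradients(model, optimizer, max_norm=1.0)
    optimizer.step()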

These are all the code changes required to prepare your script for Fabric.
You can now run it from the terminal:

.. code-block:: bash

    python path/to/your/script.py
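
Alternatively, the same script can be started by an external launcher, which Fabric detects; a hedged example, assuming ``strategy="ddp"`` on a single 8-GPU machine:

.. code-block:: bash

    torchrun --nproc_per_node=8 path/to/your/script.py
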
|

All steps combined, this is how your code will change:

.. code-block:: diff

      import torch
      from lightning.pytorch.demos import WikiText2, Transformer
    + import lightning as L

    - device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    + fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")
    + fabric.launch()

      dataset = WikiText2()
      dataloader = torch.utils.data.DataLoader(dataset)
      model = Transformer(vocab_size=dataset.vocab_size)
      optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    - model = model.to(device)
    + model, optimizer = fabric.setup(model, optimizer)
    + dataloader = fabric.setup_dataloaders(dataloader)

      model.train()
      for epoch in range(20):
          for batch in dataloader:
              input, target = batch
    -         input, target = input.to(device), target.to(device)
              optimizer.zero_grad()
              output = model(input, target)
              loss = torch.nn.functional.nll_loss(output, target.view(-1))
    -         loss.backward()
    +         fabric.backward(loss)
              optimizer.step()

That's it! You can now train on any device at any scale with a switch of a flag.

Check out our before-and-after example for `image classification <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md>`_ and many more :doc:`examples <../examples/index>` that use Fabric.

----

****************
Optional changes
****************

Here are a few optional upgrades you can make to your code, if applicable (a combined sketch follows the list):

- Replace ``torch.save()`` and ``torch.load()`` with Fabric's :doc:`save and load methods <../guide/checkpoint/checkpoint>`.
- Replace collective operations from ``torch.distributed`` (barrier, broadcast, etc.) with Fabric's :doc:`collective methods <../advanced/distributed_communication>`.
- Use Fabric's :doc:`no_backward_sync() context manager <../advanced/gradient_accumulation>` if you implemented gradient accumulation.
- Initialize your model under the :doc:`init_module() <../advanced/model_init>` context manager.
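
A hedged sketch combining these upgrades in one loop; the model class, checkpoint file name, and accumulation factor are illustrative, not prescribed:

.. code-block:: python

    import torch

    # Create parameters directly on the target device (and dtype).
    with fabric.init_module():
        model = MyModel()  # hypothetical model class returning a loss

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    for i, batch in enumerate(dataloader):
        is_accumulating = (i + 1) % 4 != 0  # accumulate over 4 steps (illustrative)
        # Skip the gradient all-reduce while still accumulating.
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            loss = model(batch)
            fabric.backward(loss)
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()

    # Synchronize processes, then save a checkpoint Fabric can reload.
    fabric.barrier()
    fabric.save("checkpoint.ckpt", {"model": model, "optimizer": optimizer})
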
----

**********
Next steps
**********

.. raw:: html

    <div class="display-card-container">
    <div class="row">

.. displayitem::
    :header: Examples
    :description: See examples across computer vision, NLP, RL, etc.
    :col_css: col-md-4
    :button_link: ../examples/index.html
    :height: 150
    :tag: basic

.. displayitem::
    :header: Accelerators
    :description: Take advantage of your hardware with a switch of a flag
    :button_link: accelerators.html
    :col_css: col-md-4
    :height: 150
    :tag: basic

.. displayitem::
    :header: Build your own Trainer
    :description: Learn how to build a trainer tailored for you
    :col_css: col-md-4
    :button_link: ../levels/intermediate
    :height: 150
    :tag: intermediate

.. raw:: html

    </div>
    </div>