Address feedback for new Lite docs (#16330)

This commit is contained in:
Adrian Wälchli 2023-01-12 13:08:32 +01:00 committed by GitHub
parent 0876a6412d
commit 426c463721
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 13 deletions

View File

@ -24,28 +24,30 @@ With only a few changes to your code, Fabric allows you to:
+ from lightning.fabric import Fabric
class MyModel(nn.Module):
class PyTorchModel(nn.Module):
...
class MyDataset(Dataset):
class PyTorchDataset(Dataset):
...
+ fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ fabric.launch()
- device = "cuda" if torch.cuda.is_available() else "cpu
model = MyModel(...)
model = PyTorchModel(...)
optimizer = torch.optim.SGD(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
dataloader = DataLoader(MyDataset(...), ...)
dataloader = DataLoader(PyTorchDataset(...), ...)
+ dataloader = fabric.setup_dataloaders(dataloader)
model.train()
for epoch in range(num_epochs):
for batch in dataloader:
- batch.to(device)
input, target = batch
- input, target = input.to(device), target.to(device)
optimizer.zero_grad()
loss = model(batch)
output = model(input)
loss = loss_fn(output, target)
- loss.backward()
+ fabric.backward(loss)
optimizer.step()

View File

@ -39,7 +39,7 @@ Here are five easy steps to let :class:`~lightning_fabric.fabric.Fabric` scale y
.. code-block:: bash
lightning run model path/to/train.py``
lightning run model path/to/train.py
or use the :meth:`~lightning_fabric.fabric.Fabric.launch` method in a notebook.
Learn more about :doc:`launching distributed training <launch>`.
@ -56,28 +56,30 @@ All steps combined, this is how your code will change:
+ from lightning.fabric import Fabric
class MyModel(nn.Module):
class PyTorchModel(nn.Module):
...
class MyDataset(Dataset):
class PyTorchDataset(Dataset):
...
+ fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ fabric.launch()
- device = "cuda" if torch.cuda.is_available() else "cpu
model = MyModel(...)
model = PyTorchModel(...)
optimizer = torch.optim.SGD(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
dataloader = DataLoader(MyDataset(...), ...)
dataloader = DataLoader(PyTorchDataset(...), ...)
+ dataloader = fabric.setup_dataloaders(dataloader)
model.train()
for epoch in range(num_epochs):
for batch in dataloader:
- batch.to(device)
input, target = batch
- input, target = input.to(device), target.to(device)
optimizer.zero_grad()
loss = model(batch)
output = model(input)
loss = loss_fn(output, target)
- loss.backward()
+ fabric.backward(loss)
optimizer.step()
@ -85,3 +87,41 @@ All steps combined, this is how your code will change:
That's it! You can now train on any device at any scale with a switch of a flag.
Check out our before-and-after example for `image classification <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md>`_ and many more :ref:`examples <Fabric Examples>` that use Fabric.
**********
Next steps
**********
.. raw:: html
<div class="display-card-container">
<div class="row">
.. displayitem::
:header: Examples
:description: See examples across computer vision, NLP, RL, etc.
:col_css: col-md-4
:button_link: ../fabric.html#examples
:height: 150
:tag: basic
.. displayitem::
:header: Accelerators
:description: Take advantage of your hardware with a switch of a flag
:button_link: accelerators.html
:col_css: col-md-4
:height: 150
:tag: intermediate
.. displayitem::
:header: Build your own Trainer
:description: Learn how to build a trainer tailored for you
:col_css: col-md-4
:button_link: ../fabric.html#build-your-own-trainer
:height: 150
:tag: intermediate
.. raw:: html
</div>
</div>