Address feedback for new Lite docs (#16330)
This commit is contained in:
parent
0876a6412d
commit
426c463721
|
@ -24,28 +24,30 @@ With only a few changes to your code, Fabric allows you to:
|
|||
|
||||
+ from lightning.fabric import Fabric
|
||||
|
||||
class MyModel(nn.Module):
|
||||
class PyTorchModel(nn.Module):
|
||||
...
|
||||
|
||||
class MyDataset(Dataset):
|
||||
class PyTorchDataset(Dataset):
|
||||
...
|
||||
|
||||
+ fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
|
||||
+ fabric.launch()
|
||||
|
||||
- device = "cuda" if torch.cuda.is_available() else "cpu
|
||||
model = MyModel(...)
|
||||
model = PyTorchModel(...)
|
||||
optimizer = torch.optim.SGD(model.parameters())
|
||||
+ model, optimizer = fabric.setup(model, optimizer)
|
||||
dataloader = DataLoader(MyDataset(...), ...)
|
||||
dataloader = DataLoader(PyTorchDataset(...), ...)
|
||||
+ dataloader = fabric.setup_dataloaders(dataloader)
|
||||
model.train()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch in dataloader:
|
||||
- batch.to(device)
|
||||
input, target = batch
|
||||
- input, target = input.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
output = model(input)
|
||||
loss = loss_fn(output, target)
|
||||
- loss.backward()
|
||||
+ fabric.backward(loss)
|
||||
optimizer.step()
|
||||
|
|
|
@ -39,7 +39,7 @@ Here are five easy steps to let :class:`~lightning_fabric.fabric.Fabric` scale y
|
|||
|
||||
.. code-block:: bash
|
||||
|
||||
lightning run model path/to/train.py``
|
||||
lightning run model path/to/train.py
|
||||
|
||||
or use the :meth:`~lightning_fabric.fabric.Fabric.launch` method in a notebook.
|
||||
Learn more about :doc:`launching distributed training <launch>`.
|
||||
|
@ -56,28 +56,30 @@ All steps combined, this is how your code will change:
|
|||
|
||||
+ from lightning.fabric import Fabric
|
||||
|
||||
class MyModel(nn.Module):
|
||||
class PyTorchModel(nn.Module):
|
||||
...
|
||||
|
||||
class MyDataset(Dataset):
|
||||
class PyTorchDataset(Dataset):
|
||||
...
|
||||
|
||||
+ fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
|
||||
+ fabric.launch()
|
||||
|
||||
- device = "cuda" if torch.cuda.is_available() else "cpu
|
||||
model = MyModel(...)
|
||||
model = PyTorchModel(...)
|
||||
optimizer = torch.optim.SGD(model.parameters())
|
||||
+ model, optimizer = fabric.setup(model, optimizer)
|
||||
dataloader = DataLoader(MyDataset(...), ...)
|
||||
dataloader = DataLoader(PyTorchDataset(...), ...)
|
||||
+ dataloader = fabric.setup_dataloaders(dataloader)
|
||||
model.train()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for batch in dataloader:
|
||||
- batch.to(device)
|
||||
input, target = batch
|
||||
- input, target = input.to(device), target.to(device)
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
output = model(input)
|
||||
loss = loss_fn(output, target)
|
||||
- loss.backward()
|
||||
+ fabric.backward(loss)
|
||||
optimizer.step()
|
||||
|
@ -85,3 +87,41 @@ All steps combined, this is how your code will change:
|
|||
|
||||
That's it! You can now train on any device at any scale with a switch of a flag.
|
||||
Check out our before-and-after example for `image classification <https://github.com/Lightning-AI/lightning/blob/master/examples/fabric/image_classifier/README.md>`_ and many more :ref:`examples <Fabric Examples>` that use Fabric.
|
||||
|
||||
**********
|
||||
Next steps
|
||||
**********
|
||||
|
||||
.. raw:: html
|
||||
|
||||
<div class="display-card-container">
|
||||
<div class="row">
|
||||
|
||||
.. displayitem::
|
||||
:header: Examples
|
||||
:description: See examples across computer vision, NLP, RL, etc.
|
||||
:col_css: col-md-4
|
||||
:button_link: ../fabric.html#examples
|
||||
:height: 150
|
||||
:tag: basic
|
||||
|
||||
.. displayitem::
|
||||
:header: Accelerators
|
||||
:description: Take advantage of your hardware with a switch of a flag
|
||||
:button_link: accelerators.html
|
||||
:col_css: col-md-4
|
||||
:height: 150
|
||||
:tag: intermediate
|
||||
|
||||
.. displayitem::
|
||||
:header: Build your own Trainer
|
||||
:description: Learn how to build a trainer tailored for you
|
||||
:col_css: col-md-4
|
||||
:button_link: ../fabric.html#build-your-own-trainer
|
||||
:height: 150
|
||||
:tag: intermediate
|
||||
|
||||
.. raw:: html
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
|
Loading…
Reference in New Issue