From 426c463721fcd9ff7e08adfa61b1e4dd72b19455 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 12 Jan 2023 13:08:32 +0100 Subject: [PATCH] Address feedback for new Lite docs (#16330) --- docs/source-pytorch/fabric/fabric.rst | 14 ++--- .../fabric/fundamentals/convert.rst | 54 ++++++++++++++++--- 2 files changed, 55 insertions(+), 13 deletions(-) diff --git a/docs/source-pytorch/fabric/fabric.rst b/docs/source-pytorch/fabric/fabric.rst index aeb680e495..779887439f 100644 --- a/docs/source-pytorch/fabric/fabric.rst +++ b/docs/source-pytorch/fabric/fabric.rst @@ -24,28 +24,30 @@ With only a few changes to your code, Fabric allows you to: + from lightning.fabric import Fabric - class MyModel(nn.Module): + class PyTorchModel(nn.Module): ... - class MyDataset(Dataset): + class PyTorchDataset(Dataset): ... + fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp") + fabric.launch() - device = "cuda" if torch.cuda.is_available() else "cpu - model = MyModel(...) + model = PyTorchModel(...) optimizer = torch.optim.SGD(model.parameters()) + model, optimizer = fabric.setup(model, optimizer) - dataloader = DataLoader(MyDataset(...), ...) + dataloader = DataLoader(PyTorchDataset(...), ...) + dataloader = fabric.setup_dataloaders(dataloader) model.train() for epoch in range(num_epochs): for batch in dataloader: - - batch.to(device) + input, target = batch + - input, target = input.to(device), target.to(device) optimizer.zero_grad() - loss = model(batch) + output = model(input) + loss = loss_fn(output, target) - loss.backward() + fabric.backward(loss) optimizer.step() diff --git a/docs/source-pytorch/fabric/fundamentals/convert.rst b/docs/source-pytorch/fabric/fundamentals/convert.rst index 1e401ba717..d823bf2ec9 100644 --- a/docs/source-pytorch/fabric/fundamentals/convert.rst +++ b/docs/source-pytorch/fabric/fundamentals/convert.rst @@ -39,7 +39,7 @@ Here are five easy steps to let :class:`~lightning_fabric.fabric.Fabric` scale y .. code-block:: bash - lightning run model path/to/train.py`` + lightning run model path/to/train.py or use the :meth:`~lightning_fabric.fabric.Fabric.launch` method in a notebook. Learn more about :doc:`launching distributed training `. @@ -56,28 +56,30 @@ All steps combined, this is how your code will change: + from lightning.fabric import Fabric - class MyModel(nn.Module): + class PyTorchModel(nn.Module): ... - class MyDataset(Dataset): + class PyTorchDataset(Dataset): ... + fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp") + fabric.launch() - device = "cuda" if torch.cuda.is_available() else "cpu - model = MyModel(...) + model = PyTorchModel(...) optimizer = torch.optim.SGD(model.parameters()) + model, optimizer = fabric.setup(model, optimizer) - dataloader = DataLoader(MyDataset(...), ...) + dataloader = DataLoader(PyTorchDataset(...), ...) + dataloader = fabric.setup_dataloaders(dataloader) model.train() for epoch in range(num_epochs): for batch in dataloader: - - batch.to(device) + input, target = batch + - input, target = input.to(device), target.to(device) optimizer.zero_grad() - loss = model(batch) + output = model(input) + loss = loss_fn(output, target) - loss.backward() + fabric.backward(loss) optimizer.step() @@ -85,3 +87,41 @@ All steps combined, this is how your code will change: That's it! You can now train on any device at any scale with a switch of a flag. Check out our before-and-after example for `image classification `_ and many more :ref:`examples ` that use Fabric. + +********** +Next steps +********** + +.. raw:: html + +
+
+ +.. displayitem:: + :header: Examples + :description: See examples across computer vision, NLP, RL, etc. + :col_css: col-md-4 + :button_link: ../fabric.html#examples + :height: 150 + :tag: basic + +.. displayitem:: + :header: Accelerators + :description: Take advantage of your hardware with a switch of a flag + :button_link: accelerators.html + :col_css: col-md-4 + :height: 150 + :tag: intermediate + +.. displayitem:: + :header: Build your own Trainer + :description: Learn how to build a trainer tailored for you + :col_css: col-md-4 + :button_link: ../fabric.html#build-your-own-trainer + :height: 150 + :tag: intermediate + +.. raw:: html + +
+