.. testsetup:: *

    import os

    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, random_split

    import pytorch_lightning as pl
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer

.. _new_project:

####################
Lightning in 2 Steps
####################

**In this guide, we'll show you how to organize your PyTorch code into Lightning in 2 steps.**

Organizing your code with PyTorch Lightning makes your code:

* As flexible as ever (this is all still pure PyTorch), but with a ton of boilerplate removed
* More readable, by decoupling the research code from the engineering
* Easier to reproduce
* Less error-prone, by automating most of the training loop and the tricky engineering
* Scalable to any hardware without changing your model

----------

Here's a 3-minute conversion guide for PyTorch projects:

.. raw:: html

    <video width="100%" max-width="800px" controls autoplay muted playsinline
    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_docs_animation_final.m4v"></video>

----------

*********************************
Step 0: Install PyTorch Lightning
*********************************

You can install Lightning with `pip <https://pypi.org/project/pytorch-lightning/>`_:

.. code-block:: bash

    pip install pytorch-lightning

Or with `conda <https://anaconda.org/conda-forge/pytorch-lightning>`_ (see how to install conda `here <https://docs.conda.io/projects/conda/en/latest/user-guide/install/>`_):

.. code-block:: bash

    conda install pytorch-lightning -c conda-forge

You can also install it inside a conda environment:

.. code-block:: bash

    conda activate my_env
    pip install pytorch-lightning
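
To verify the install, you can print the package version from the command line (a quick sanity check; the exact version printed will differ):

.. code-block:: bash

    python -c "import pytorch_lightning as pl; print(pl.__version__)"
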
----------

Import the following:

.. testcode::
    :skipif: not _TORCHVISION_AVAILABLE

    import os

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST

    import pytorch_lightning as pl

Step 1: Define a LightningModule
================================

.. testcode::

    class LitAutoEncoder(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

        def forward(self, x):
            # in Lightning, forward defines the prediction/inference actions
            embedding = self.encoder(x)
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop; it is independent of forward
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            # logging to TensorBoard by default
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

**SYSTEM VS MODEL**

A :doc:`LightningModule <../common/lightning_module>` defines a *system*, not just a model.

.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/model_system.png
    :width: 400

Examples of systems are:

- `Autoencoder <https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/autoencoder.py>`_
- `BERT <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/text-transformers.ipynb>`_
- `DQN <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/reinforce-learning-DQN.ipynb>`_
- `GAN <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/basic-gan.ipynb>`_
- `Image classifier <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb>`_
- `Semantic segmentation <https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/semantic_segmentation.py>`_
- `and a lot more <https://github.com/PyTorchLightning/lightning-tutorials/tree/publication/.notebooks/lightning_examples>`_

Under the hood, a LightningModule is still just a :class:`torch.nn.Module` that groups all research code into a single file to make it self-contained:

- The train loop
- The validation loop
- The test loop
- The prediction loop
- The model or system of models
- The optimizers and LR schedulers
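
For example, the validation loop is defined by implementing a single hook. Here is a minimal sketch extending the ``LitAutoEncoder`` above (the ``val_loss`` metric name is just an illustrative choice):

.. code-block:: python

    class LitAutoEncoder(pl.LightningModule):
        ...

        def validation_step(self, batch, batch_idx):
            # validation_step defines one step of the validation loop
            x, y = batch
            x = x.view(x.size(0), -1)
            x_hat = self.decoder(self.encoder(x))
            self.log("val_loss", F.mse_loss(x_hat, x))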

You can customize any part of training (such as the backward pass) by overriding any
of the 20+ hooks found in :ref:`lightning_hooks`:

.. testcode::

    class LitAutoEncoder(pl.LightningModule):
        def backward(self, loss, optimizer, optimizer_idx):
            loss.backward()

**FORWARD vs TRAINING_STEP**

In Lightning, we suggest separating training from inference. The ``training_step`` defines
the full training loop. We encourage users to use ``forward`` to define inference actions.

For example, in this case we can define the autoencoder to act as an embedding extractor:

.. code-block:: python

    def forward(self, batch):
        embeddings = self.encoder(batch)
        return embeddings

Of course, nothing is preventing you from using ``forward`` from within the ``training_step``.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        ...
        embeddings = self.encoder(batch)
        output = self.decoder(embeddings)

It really comes down to your application. We do, however, recommend that you keep both intents separate:

* Use ``forward`` for inference (predicting), as sketched below.
* Use ``training_step`` for training.
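
A minimal sketch of inference through ``forward``, assuming a trained ``autoencoder`` instance and a batch of flattened images ``x``:

.. code-block:: python

    autoencoder.eval()
    with torch.no_grad():
        # calling the module invokes forward and returns the embeddings
        embeddings = autoencoder(x)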

More details in the :doc:`LightningModule <../common/lightning_module>` docs.

----------

Step 2: Fit with Lightning Trainer
==================================

First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test/predict splits.

.. code-block:: python

    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset)
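
If you also want a validation split, here is a minimal sketch using ``random_split`` (the 55,000/5,000 split of MNIST's 60,000 training images is an illustrative choice):

.. code-block:: python

    # carve a validation set out of the training data
    train_set, val_set = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(train_set)
    val_loader = DataLoader(val_set)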

Next, init the :doc:`LightningModule <../common/lightning_module>` and the PyTorch Lightning :doc:`Trainer <../common/trainer>`,
then call fit with both the data and model.

.. code-block:: python

    # init model
    autoencoder = LitAutoEncoder()

    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
    # trainer = pl.Trainer(accelerator="gpu", devices=8) (if you have GPUs)
    trainer = pl.Trainer()
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)

The :class:`~pytorch_lightning.trainer.Trainer` automates:

* Epoch and batch iteration
* ``optimizer.step()``, ``loss.backward()``, ``optimizer.zero_grad()`` calls
* Calling ``model.eval()`` and enabling/disabling grads during evaluation
* :doc:`Checkpoint saving and loading <../common/checkpointing>`
* TensorBoard logging (see :doc:`loggers <../common/loggers>` for other options)
* :ref:`Multi-GPU <accelerators/gpu:Multi GPU Training>` support
* :doc:`TPU <../accelerators/tpu>` support
* :ref:`16-bit precision AMP <amp>` support

.. tip:: If you prefer to manage optimizers manually, use the :ref:`manual_opt` mode (useful for RL, GANs, and so on).

**That's it!**

These are the two main components you need to know in Lightning. All the other features of Lightning are either
features of the Trainer or the LightningModule, or extensions for advanced use cases.

-----------

**************
Basic Features
**************

Manual vs Automatic Optimization
================================

Automatic Optimization
----------------------

With Lightning, you don't need to worry about when to enable/disable grads, do a backward pass, or update optimizers:
as long as you return a loss with an attached graph from the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method,
Lightning will automate the optimization.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = self.encoder(batch)
        return loss

.. _manual_opt:

Manual Optimization
-------------------

For certain research, like GANs, reinforcement learning, or anything with multiple optimizers
or an inner loop, you can turn off automatic optimization and fully control the optimization yourself.

.. code-block:: python

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False


    def training_step(self, batch, batch_idx):
        # access your optimizers with use_pl_optimizer=False. Default is True,
        # setting use_pl_optimizer=True will maintain plugin/precision support
        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

        loss_a = self.generator(batch)
        opt_a.zero_grad()
        # use `manual_backward()` instead of `loss.backward()` to automate half precision, etc.
        self.manual_backward(loss_a)
        opt_a.step()

        loss_b = self.discriminator(batch)
        opt_b.zero_grad()
        self.manual_backward(loss_b)
        opt_b.step()

Loop Customization
==================

If you need even more flexibility, you can fully customize the training loop down to its core; this is usually
only required for advanced use cases. Learn more in the :doc:`Loops docs <../extensions/loops>`.
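
As a rough sketch of the shape of this API (see the Loops docs for the authoritative contract), a custom loop implements three pieces:

.. code-block:: python

    from pytorch_lightning.loops import Loop


    class MyFancyLoop(Loop):
        @property
        def done(self):
            # return True once the loop should stop
            ...

        def reset(self):
            # reset any internal state so the loop can run again
            ...

        def advance(self, *args, **kwargs):
            # run one iteration of the loop
            ...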

Predict or Deploy
=================

When you're done training, you have three options for using your LightningModule for predictions.

Option 1: Sub-models
--------------------

Pull out any model inside your system for predictions.

.. code-block:: python

    # ----------------------------------
    # to use as embedding extractor
    # ----------------------------------
    autoencoder = LitAutoEncoder.load_from_checkpoint("path/to/checkpoint_file.ckpt")
    encoder_model = autoencoder.encoder
    encoder_model.eval()

    # ----------------------------------
    # to use as image generator
    # ----------------------------------
    decoder_model = autoencoder.decoder
    decoder_model.eval()

Option 2: Forward
-----------------

You can also add a ``forward`` method to do predictions however you want.

.. testcode::

    # ----------------------------------
    # using the AE to extract embeddings
    # ----------------------------------
    class LitAutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64))

        def forward(self, x):
            embedding = self.encoder(x)
            return embedding


    autoencoder = LitAutoEncoder()
    embedding = autoencoder(torch.rand(1, 28 * 28))

.. code-block:: python

    # -------------------------------
    # using the AE to generate images
    # -------------------------------
    class LitAutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.decoder = nn.Sequential(nn.Linear(64, 28 * 28))

        def forward(self):
            z = torch.rand(1, 64)
            image = self.decoder(z)
            image = image.view(1, 1, 28, 28)
            return image


    autoencoder = LitAutoEncoder()
    image_sample = autoencoder()
2020-09-22 10:00:54 +00:00
Option 3: Production
2020-10-11 17:12:19 +00:00
--------------------
2022-01-13 21:11:43 +00:00
2022-02-21 21:21:12 +00:00
For production systems, `ONNX <https://pytorch.org/docs/stable/onnx.html> `_ or `TorchScript <https://pytorch.org/docs/stable/jit.html> `_ is much faster.
2022-01-13 21:11:43 +00:00
Make sure you have added a `` forward `` method or trace only the sub-models you need.
* TorchScript using :meth: `~pytorch_lightning.core.lightning.LightningModule.to_torchscript` method.
2020-09-21 15:17:59 +00:00
.. code-block :: python
2020-09-22 10:00:54 +00:00
autoencoder = LitAutoEncoder()
2022-01-13 21:11:43 +00:00
autoencoder.to_torchscript(file_path="model.pt")
* Onnx using :meth: `~pytorch_lightning.core.lightning.LightningModule.to_onnx` method.
2020-09-21 15:17:59 +00:00
2020-09-22 10:00:54 +00:00
.. code-block :: python
2022-01-13 21:11:43 +00:00
autoencoder = LitAutoEncoder()
input_sample = torch.randn((1, 28 * 28))
autoencoder.to_onnx(file_path="model.onnx", input_sample=input_sample, export_params=True)
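
Once exported, the ONNX model can be served without PyTorch. A minimal sketch, assuming the ``onnxruntime`` package is installed:

.. code-block:: python

    import numpy as np
    import onnxruntime

    session = onnxruntime.InferenceSession("model.onnx")
    input_name = session.get_inputs()[0].name
    # run inference on a random flattened image
    outputs = session.run(None, {input_name: np.random.randn(1, 28 * 28).astype(np.float32)})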

Using Accelerators
==================

It's easy to use CPUs, GPUs, TPUs, or IPUs in Lightning. There's **no need** to change your code; simply change the :class:`~pytorch_lightning.trainer.trainer.Trainer` options.

CPU
---

.. testcode::

    # train on CPU
    trainer = Trainer()

    # train on 8 CPUs
    trainer = Trainer(accelerator="cpu", devices=8)

    # train on 1024 CPUs across 128 machines
    trainer = pl.Trainer(accelerator="cpu", devices=8, num_nodes=128)

GPU
---

.. code-block:: python

    # train on 1 GPU
    trainer = pl.Trainer(accelerator="gpu", devices=1)

    # train on multiple GPUs across nodes (32 GPUs here)
    trainer = pl.Trainer(accelerator="gpu", devices=4, num_nodes=8)

    # train on GPUs 1, 3 and 5 (3 GPUs total)
    trainer = pl.Trainer(accelerator="gpu", devices=[1, 3, 5])

    # multi-GPU with mixed precision
    trainer = pl.Trainer(accelerator="gpu", devices=2, precision=16)

TPU
---

.. code-block:: python

    # train on 8 TPU cores
    trainer = pl.Trainer(accelerator="tpu", devices=8)

    # train on a single TPU core
    trainer = pl.Trainer(accelerator="tpu", devices=1)

    # train on the 7th TPU core
    trainer = pl.Trainer(accelerator="tpu", devices=[7])

    # without changing a SINGLE line of your code, you can train on TPUs using 16-bit precision,
    # using only half the training data and checking validation every quarter of a training epoch
    trainer = pl.Trainer(accelerator="tpu", devices=8, precision=16, limit_train_batches=0.5, val_check_interval=0.25)

IPU
---

.. code-block:: python

    # train on IPUs
    trainer = pl.Trainer(accelerator="ipu", devices=8)

Checkpointing
=============

Lightning automatically saves your model. Once you've trained, you can load the checkpoints as follows:

.. code-block:: python

    model = LitModel.load_from_checkpoint(path_to_saved_checkpoint)

The above checkpoint contains all the arguments needed to init the model and set the state dict.
If you prefer to do it manually, here's the equivalent:

.. code-block:: python

    # load the ckpt
    ckpt = torch.load("path/to/checkpoint.ckpt")

    # equivalent to the above
    model = LitModel()
    model.load_state_dict(ckpt["state_dict"])
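
You can also trigger a save yourself at any point via the Trainer's ``save_checkpoint`` method (the file name below is an illustrative choice):

.. code-block:: python

    trainer.save_checkpoint("example.ckpt")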

Learn more in the :ref:`Checkpointing docs <checkpointing>`.

Data Flow
=========

Each loop (training, validation, test, predict) has three hooks you can implement:

- x_step
- x_step_end (optional)
- x_epoch_end (optional)

To illustrate how data flows, we'll use the training loop (i.e., x=training):

.. code-block:: python

    outs = []
    for batch in data:
        out = training_step(batch)
        out = training_step_end(out)
        outs.append(out)
    training_epoch_end(outs)

The equivalent in Lightning is:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        prediction = ...
        return prediction


    def training_epoch_end(self, outs):
        for out in outs:
            ...

In the event that you use DP or DDP2 distributed modes (i.e., split a batch across devices), see the *Training with DataParallel* section :ref:`here <lightning_module>`.

The validation, test, and prediction loops have the same structure.
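
For example, the validation counterparts of the hooks above follow the same pattern with ``x=validation`` (a minimal sketch):

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        prediction = ...
        return prediction


    def validation_epoch_end(self, val_outs):
        for out in val_outs:
            ...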

----------------

*******************
Optional Extensions
*******************

Check out the following optional extensions that can make your ML pipelines more robust:

* :ref:`LightningDataModule <datamodules>`
* :ref:`Callbacks <callbacks>`
* :ref:`Logging <logging>`
* :ref:`Accelerators <accelerators>`
* :ref:`Plugins <plugins>`
* :ref:`Loops <loop_customization>`
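
As a taste of the first one, here is a minimal sketch of a ``LightningDataModule`` that wraps the MNIST loaders from earlier (the 55,000/5,000 split is an illustrative choice):

.. code-block:: python

    class MNISTDataModule(pl.LightningDataModule):
        def setup(self, stage=None):
            dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
            self.train_set, self.val_set = random_split(dataset, [55000, 5000])

        def train_dataloader(self):
            return DataLoader(self.train_set)

        def val_dataloader(self):
            return DataLoader(self.val_set)

A datamodule can then be passed straight to fit: ``trainer.fit(model, datamodule=MNISTDataModule())``.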

----------------

*********
Debugging
*********

Lightning has many tools for debugging. Here are a few of them:

Limit Batches
=============

.. testcode::

    # use only 10 train batches and 3 val batches per epoch
    trainer = Trainer(limit_train_batches=10, limit_val_batches=3)

    # use 20% of total train batches and 10% of total val batches per epoch
    trainer = Trainer(limit_train_batches=0.2, limit_val_batches=0.1)

Overfit Batches
===============

.. testcode::

    # automatically overfit the same batches to your model for a sanity test

    # use only 10 train batches
    trainer = Trainer(overfit_batches=10)

    # use only 20% of total train batches
    trainer = Trainer(overfit_batches=0.2)

Fast Dev Run
============

.. testcode::

    # unit-test all the code - hits every line of your code once to see if you have bugs,
    # instead of waiting hours to crash somewhere
    trainer = Trainer(fast_dev_run=True)

    # unit-test all the code - hits every line of your code with 4 batches
    trainer = Trainer(fast_dev_run=4)

Val Check Interval
==================

.. testcode::

    # run validation every 25% of a training epoch
    trainer = Trainer(val_check_interval=0.25)

Profiling
=========

.. testcode::

    # profile your code to find speed/memory bottlenecks
    trainer = Trainer(profiler="simple")

---------------

*******************
Other Cool Features
*******************

Once you define and train your first Lightning model, you might want to try other cool features like:

- :doc:`Automatic early stopping <../common/early_stopping>`
- :ref:`Automatic truncated back-propagation through time <common/lightning_module:truncated_bptt_steps>`
- :ref:`Automatically scale your batch size <advanced/training_tricks:Batch Size Finder>`
- :ref:`Automatically find a good learning rate <advanced/training_tricks:Learning Rate Finder>`
- :ref:`Load checkpoints directly from S3 <common/checkpointing:Checkpoint Loading>`
- :doc:`Scale to massive compute clusters <../clouds/cluster>`
- :doc:`Use multiple dataloaders per train/val/test/predict loop <../guides/data>`
- :ref:`Use multiple optimizers to do reinforcement learning or even GANs <common/optimization:Use multiple optimizers (like GANs)>`
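
For instance, automatic early stopping takes a single callback. A minimal sketch, assuming your module logs a ``val_loss`` metric:

.. code-block:: python

    from pytorch_lightning.callbacks import EarlyStopping

    trainer = pl.Trainer(callbacks=[EarlyStopping(monitor="val_loss")])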

Read our :doc:`Guide <../starter/core_guide>` to learn more with a step-by-step walk-through!

-------------

*****************
Starter Templates
*****************

Before installing anything, use the following templates to try it out live:

.. list-table::
    :widths: 18 15 25
    :header-rows: 1

    * - Use case
      - Description
      - Link
    * - Scratch model
      - To prototype quickly / debug with random data
      -
        .. raw:: html

            <div style='width:150px;height:auto'>
                <a href="https://colab.research.google.com/drive/1rHBxrtopwtF8iLpmC_e7yl3TeDGrseJL?usp=sharing">
                    <img alt="open in colab" src="http://bit.ly/pl_colab">
                </a>
            </div>
    * - Scratch model with manual optimization
      - To prototype quickly / debug with random data
      -
        .. raw:: html

            <div style='width:150px;height:auto'>
                <a href="https://colab.research.google.com/drive/1nGtvBFirIvtNQdppe2xBes6aJnZMjvl8?usp=sharing">
                    <img alt="open in colab" src="http://bit.ly/pl_colab">
                </a>
            </div>

------------

*******
Grid AI
*******

Grid AI is our native solution for large-scale training and tuning on the cloud.

`Get started for free with your GitHub or Google account here <https://www.grid.ai/>`_.

------------

*********
Community
*********

Our community of core maintainers and thousands of expert researchers is active on our
`Slack <https://join.slack.com/t/pytorch-lightning/shared_invite/zt-12iz3cds1-uyyyBYJLiaL2bqVmMN7n~A>`_
and `GitHub Discussions <https://github.com/PyTorchLightning/pytorch-lightning/discussions>`_. Drop by
to hang out, ask Lightning questions, or even discuss research!

-------------

***********
Masterclass
***********

We also offer a Masterclass to teach you the advanced uses of Lightning.

.. image:: ../_static/images/general/PTL101_youtube_thumbnail.jpg
    :width: 500
    :align: center
    :alt: Masterclass
    :target: https://www.youtube.com/playlist?list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2