.. testsetup:: *

    import os
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torch.utils.data import random_split

    import pytorch_lightning as pl
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer

.. _new_project:
############
Introduction
############

**************************
What is PyTorch Lightning?
**************************

PyTorch Lightning provides you with the APIs required to build models, datasets, and so on. PyTorch has all you need to train your models; however, there's much more to deep learning than attaching layers. When it comes to the actual training, there's a lot of boilerplate code that you need to write, and if you need to scale your training/inference across multiple devices or machines, there is yet another set of integrations you may need to handle.

PyTorch Lightning solves these for you. All you need to do is restructure some of your existing code and set certain flags, and then you are done.
Now you can train your models on different accelerators like GPU/TPU/IPU and run distributed training across multiple machines/nodes without code changes, using state-of-the-art distributed training mechanisms.

Code organization is the core of Lightning. It leaves the research logic to you and automates the rest.
----------

********************
Lightning Philosophy
********************

Organizing your code with Lightning makes your code:
* Flexible (this is all pure PyTorch) while removing a ton of boilerplate
* More readable by decoupling the research code from the engineering
* Easier to reproduce
* Less error-prone by automating most of the training loop and tricky engineering
* Scalable to any hardware without changing your model
Lightning is built for:
* Researchers who want to focus on research without worrying about the engineering aspects of it
* ML Engineers who want to build reproducible pipelines
* Data Scientists who want to try out different models for their tasks and built-in ML techniques
* Educators who seek to study and teach Deep Learning with PyTorch
The team makes sure that all the latest techniques are already integrated and well maintained.
----------
*****************
Starter Templates
*****************

Before installing anything, use the following templates to try it out live:
.. list-table::
   :widths: 18 15 25
   :header-rows: 1

   * - Use case
     - Description
     - Link
   * - Scratch model
     - To prototype quickly / debug with random data
     -
       .. raw:: html

           <div style='width:150px;height:auto'>
               <a href="https://colab.research.google.com/drive/1rHBxrtopwtF8iLpmC_e7yl3TeDGrseJL?usp=sharing">
                   <img alt="open in colab" src="http://bit.ly/pl_colab">
               </a>
           </div>
   * - Scratch model with manual optimization
     - To prototype quickly / debug with random data
     -
       .. raw:: html

           <div style='width:150px;height:auto'>
               <a href="https://colab.research.google.com/drive/1nGtvBFirIvtNQdppe2xBes6aJnZMjvl8?usp=sharing">
                   <img alt="open in colab" src="http://bit.ly/pl_colab">
               </a>
           </div>

----------
************
Installation
************

Follow the :ref:`Installation Guide <installation>` to install PyTorch Lightning.

----------
********************
Lightning Components
********************

Here's a 3-minute conversion guide for PyTorch projects:
.. raw:: html

    <video width="100%" max-width="800px" controls autoplay muted playsinline
    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_docs_animation_final.m4v"></video>

Import the following:
.. testcode::
    :skipif: not _TORCHVISION_AVAILABLE

    import os
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision import transforms
    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader, random_split
    import pytorch_lightning as pl


Step 1: Define LightningModule
==============================

.. testcode::

    class LitAutoEncoder(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

        def forward(self, x):
            # in lightning, forward defines the prediction/inference actions
            embedding = self.encoder(x)
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop.
            # It is independent of forward
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            # Logging to TensorBoard by default
            self.log("train_loss", loss)
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

**SYSTEM VS MODEL**
A :doc:`lightning module <../common/lightning_module>` defines a *system*, not just a model.

.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/model_system.png
    :width: 400

Examples of systems are:
- `Autoencoder <https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/basic_examples/autoencoder.py>`_
- `BERT <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/text-transformers.ipynb>`_
- `DQN <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/reinforce-learning-DQN.ipynb>`_
- `GAN <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/basic-gan.ipynb>`_
- `Image classifier <https://colab.research.google.com/github/PyTorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb>`_
- `Semantic Segmentation <https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/semantic_segmentation.py>`_
- `and a lot more <https://github.com/PyTorchLightning/lightning-tutorials/tree/publication/.notebooks/lightning_examples>`_

Under the hood, a LightningModule is still just a :class:`torch.nn.Module` that groups all research code into a single file to make it self-contained:

- The Train loop
- The Validation loop
- The Test loop
- The Prediction loop
- The Model or system of Models
- The Optimizers and LR Schedulers
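
As a sketch of what this looks like in practice, the ``LitAutoEncoder`` above could also implement the validation and test loops by adding the corresponding hooks. The additions below are illustrative and not part of the minimal example:

.. code-block:: python

    class LitAutoEncoder(pl.LightningModule):
        ...

        def validation_step(self, batch, batch_idx):
            # the validation loop: same reconstruction loss, logged under a different name
            x, y = batch
            x = x.view(x.size(0), -1)
            x_hat = self.decoder(self.encoder(x))
            self.log("val_loss", F.mse_loss(x_hat, x))

        def test_step(self, batch, batch_idx):
            # the test loop mirrors the validation loop
            x, y = batch
            x = x.view(x.size(0), -1)
            x_hat = self.decoder(self.encoder(x))
            self.log("test_loss", F.mse_loss(x_hat, x))
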
You can customize any part of training (such as the backward pass) by overriding any
of the 20+ hooks found in :ref:`lightning_hooks`.

.. testcode::

    class LitAutoEncoder(pl.LightningModule):
        def backward(self, loss, optimizer, optimizer_idx):
            loss.backward()

**FORWARD vs TRAINING_STEP**
In Lightning we suggest separating training from inference. The ``training_step`` defines
the full training loop. We encourage users to use ``forward`` to define inference actions.

For example, in this case we can define the autoencoder to act as an embedding extractor:
.. code-block:: python

    def forward(self, batch):
        embeddings = self.encoder(batch)
        return embeddings

Of course, nothing is preventing you from using ``forward`` from within the ``training_step``.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        ...
        embeddings = self.encoder(batch)
        output = self.decoder(embeddings)

It really comes down to your application. We do, however, recommend that you keep both intents separate.
* Use ``forward`` for inference (predicting). A ``predict_step`` sketch is shown below.
* Use ``training_step`` for training.
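
For completeness, here is a minimal, illustrative sketch of how the prediction loop can reuse ``forward`` through the optional ``predict_step`` hook (by default, Lightning already calls ``forward`` for you during prediction):

.. code-block:: python

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # reuse forward (i.e., self(...)) to produce embeddings for a batch
        x, y = batch
        x = x.view(x.size(0), -1)
        return self(x)
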
More details in the :doc:`LightningModule <../common/lightning_module>` docs.

----------

Step 2: Fit with Lightning Trainer
==================================

First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test/predict splits.

.. code-block:: python

    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset)

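If you also want a validation split, one minimal sketch (using the ``random_split`` import from above; the 55000/5000 sizes assume the standard 60000-sample MNIST train set) is:

.. code-block:: python

    train_set, val_set = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(train_set)
    val_loader = DataLoader(val_set)

The ``val_loader`` can later be passed to ``trainer.fit`` as ``val_dataloaders=val_loader``, provided your LightningModule implements a ``validation_step``.
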
Next, init the :doc:`LightningModule <../common/lightning_module>` and the PyTorch Lightning :doc:`Trainer <../common/trainer>`,
then call fit with both the data and model.

.. code-block:: python

    # init model
    autoencoder = LitAutoEncoder()

    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
    # trainer = pl.Trainer(accelerator="gpu", devices=8) (if you have GPUs)
    trainer = pl.Trainer()
    trainer.fit(model=autoencoder, train_dataloaders=train_loader)

The :class:`~pytorch_lightning.trainer.Trainer` automates:

* Epoch and batch iteration
* ``optimizer.step()``, ``loss.backward()``, ``optimizer.zero_grad()`` calls
* Calling of ``model.eval()``, enabling/disabling grads during evaluation
* :doc:`Checkpoint Saving and Loading <../common/checkpointing>`
* Tensorboard (see :doc:`loggers <../common/loggers>` options)
* :ref:`Multi-GPU <accelerators/gpu:Multi GPU Training>` support
* :doc:`TPU <../accelerators/tpu>`
* :ref:`16-bit precision AMP <amp>` support

.. tip:: If you prefer to manually manage optimizers, you can use the :ref:`manual_opt` mode (i.e., RL, GANs, and so on).

**That's it!**
These are the two main components you need to know in Lightning. All the other features of Lightning are either
features of the Trainer or LightningModule, or are extensions for advanced use-cases.

-----------

**************
Basic Features
**************

Manual vs Automatic Optimization
================================

Automatic Optimization
----------------------

With Lightning, you don't need to worry about when to enable/disable grads, do a backward pass, or update the optimizers.
As long as you return a loss with an attached graph from the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` method,
Lightning will automate the optimization.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = self.encoder(batch)
        return loss

.. _manual_opt:
Manual Optimization
-------------------

For certain research like GANs, reinforcement learning, or anything with multiple optimizers
or an inner loop, you can turn off automatic optimization and fully control the training loop yourself.

Turn off automatic optimization, and you control the train loop!

.. code-block:: python

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False


    def training_step(self, batch, batch_idx):
        # access your optimizers with use_pl_optimizer=False. Default is True,
        # setting use_pl_optimizer=True will maintain plugin/precision support
        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

        loss_a = self.generator(batch)
        opt_a.zero_grad()
        # use `manual_backward()` instead of `loss.backward()` to automate half precision, etc...
        self.manual_backward(loss_a)
        opt_a.step()

        loss_b = self.discriminator(batch)
        opt_b.zero_grad()
        self.manual_backward(loss_b)
        opt_b.step()

Loop Customization
==================

If you need even more flexibility, you can fully customize the training loop down to its core. Such customizations
are usually only required for advanced use-cases. Learn more in the :doc:`Loops docs <../extensions/loops>`.

Predict or Deploy
=================

When you're done training, you have three options to use your LightningModule for predictions.
Option 1: Sub-models
--------------------

Pull out any model inside your system for predictions.
.. code-block:: python

    # ----------------------------------
    # to use as embedding extractor
    # ----------------------------------
    autoencoder = LitAutoEncoder.load_from_checkpoint("path/to/checkpoint_file.ckpt")
    encoder_model = autoencoder.encoder
    encoder_model.eval()

    # ----------------------------------
    # to use as image generator
    # ----------------------------------
    decoder_model = autoencoder.decoder
    decoder_model.eval()

Option 2: Forward
-----------------

You can also add a forward method to do predictions however you want.
.. testcode::

    # ----------------------------------
    # using the AE to extract embeddings
    # ----------------------------------
    class LitAutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64))

        def forward(self, x):
            embedding = self.encoder(x)
            return embedding


    autoencoder = LitAutoEncoder()
    embedding = autoencoder(torch.rand(1, 28 * 28))

.. code-block:: python

    # -------------------------------
    # using the AE to generate images
    # -------------------------------
    class LitAutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.decoder = nn.Sequential(nn.Linear(64, 28 * 28))

        def forward(self):
            z = torch.rand(1, 64)
            image = self.decoder(z)
            image = image.view(1, 1, 28, 28)
            return image


    autoencoder = LitAutoEncoder()
    image_sample = autoencoder()

Option 3: Production
--------------------

For production systems, `ONNX <https://pytorch.org/docs/stable/onnx.html>`_ or `TorchScript <https://pytorch.org/docs/stable/jit.html>`_ is much faster.

Make sure you have added a ``forward`` method or trace only the sub-models you need.

* TorchScript using the :meth:`~pytorch_lightning.core.lightning.LightningModule.to_torchscript` method.

.. code-block:: python

    autoencoder = LitAutoEncoder()
    autoencoder.to_torchscript(file_path="model.pt")

* ONNX using the :meth:`~pytorch_lightning.core.lightning.LightningModule.to_onnx` method.

.. code-block:: python

    autoencoder = LitAutoEncoder()
    input_sample = torch.randn((1, 28 * 28))
    autoencoder.to_onnx(file_path="model.onnx", input_sample=input_sample, export_params=True)

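Once exported, the ONNX file can be served by any ONNX-compatible runtime. A minimal sketch, assuming the ``onnxruntime`` package is installed (it is not a Lightning dependency):

.. code-block:: python

    import onnxruntime

    # load the exported model and run a single batch through it
    ort_session = onnxruntime.InferenceSession("model.onnx")
    input_name = ort_session.get_inputs()[0].name
    ort_outs = ort_session.run(None, {input_name: input_sample.numpy()})
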
Using Accelerators
==================
It's easy to use CPUs, GPUs, TPUs or IPUs in Lightning. There's **no need** to change your code; simply change the :class:`~pytorch_lightning.trainer.trainer.Trainer` options.

CPU
---
.. testcode::

    # train on CPU
    trainer = Trainer()

    # train on 8 CPUs
    trainer = Trainer(num_processes=8)

    # train on 1024 CPUs across 128 machines
    trainer = pl.Trainer(num_processes=8, num_nodes=128)

GPU
---
.. code-block:: python

    # train on 1 GPU
    trainer = pl.Trainer(accelerator="gpu", devices=1)

    # train on multiple GPUs across nodes (32 GPUs here)
    trainer = pl.Trainer(accelerator="gpu", devices=4, num_nodes=8)

    # train on gpu 1, 3, 5 (3 GPUs total)
    trainer = pl.Trainer(accelerator="gpu", devices=[1, 3, 5])

    # Multi GPU with mixed precision
    trainer = pl.Trainer(accelerator="gpu", devices=2, precision=16)

TPU
---
.. code-block:: python

    # Train on 8 TPU cores
    trainer = pl.Trainer(accelerator="tpu", devices=8)

    # Train on a single TPU core
    trainer = pl.Trainer(accelerator="tpu", devices=1)

    # Train on the 7th TPU core
    trainer = pl.Trainer(accelerator="tpu", devices=[7])

    # without changing a SINGLE line of your code, you can
    # train on TPUs using 16-bit precision,
    # using only half the training data and checking validation every quarter of a training epoch
    trainer = pl.Trainer(accelerator="tpu", devices=8, precision=16, limit_train_batches=0.5, val_check_interval=0.25)

IPU
---
.. code-block:: python

    # Train on IPUs
    trainer = pl.Trainer(ipus=8)

Checkpointing
=============
Lightning automatically saves your model. Once you've trained, you can load the checkpoints as follows:
.. code-block:: python

    model = LitModel.load_from_checkpoint(path_to_saved_checkpoint)

The above checkpoint contains all the arguments needed to init the model and set the state dict.
If you prefer to do it manually, here's the equivalent:

.. code-block:: python

    # load the ckpt
    ckpt = torch.load("path/to/checkpoint.ckpt")

    # equivalent to the above
    model = LitModel()
    model.load_state_dict(ckpt["state_dict"])

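You can also resume training from a saved checkpoint. A minimal sketch, assuming a Lightning version where ``fit`` accepts the ``ckpt_path`` argument:

.. code-block:: python

    model = LitModel()
    trainer = Trainer()

    # restores model weights plus epoch, optimizer and scheduler state, then continues training
    trainer.fit(model, train_dataloaders=train_loader, ckpt_path="path/to/checkpoint.ckpt")
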
Learn more in the :ref:`Checkpoint docs <checkpointing>`.

Data Flow
=========

Each loop (training, validation, test, predict) has three hooks you can implement:
- x_step
- x_step_end (optional)
- x_epoch_end (optional)

To illustrate how data flows, we'll use the training loop (i.e., x=training):

.. code-block:: python

    outs = []
    for batch in data:
        out = training_step(batch)
        out = training_step_end(out)
        outs.append(out)
    training_epoch_end(outs)

The equivalent in Lightning is:
.. code-block:: python

    def training_step(self, batch, batch_idx):
        prediction = ...
        return prediction


    def training_epoch_end(self, outs):
        for out in outs:
            ...

In the event you use DP or DDP2 distributed modes (i.e., split a batch across devices), check out the *Training with DataParallel* section :ref:`here <lightning_module>`.

The validation, test and prediction loops have the same structure.
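
For example, a sketch of the validation loop following the same pattern (the hook names are the real ones; the loss computation is illustrative):

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        return F.mse_loss(x_hat, x)


    def validation_epoch_end(self, outs):
        # `outs` is the list of values returned by `validation_step` for each batch
        self.log("avg_val_loss", torch.stack(outs).mean())
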
----------------
*******************
Optional Extensions
*******************

Check out the following optional extensions that can make your ML pipelines more robust:

* :ref:`LightningDataModule <datamodules>` (see the sketch below)
* :ref:`Callbacks <callbacks>`
* :ref:`Logging <logging>`
* :ref:`Accelerators <accelerators>`
* :ref:`Plugins <plugins>`
* :ref:`Loops <loop_customization>`

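As a taste of the first one, here is a minimal, illustrative sketch of a :ref:`LightningDataModule <datamodules>` wrapping the MNIST loaders used earlier (the class name and split sizes are just examples):

.. code-block:: python

    class MNISTDataModule(LightningDataModule):
        def __init__(self, data_dir: str = os.getcwd(), batch_size: int = 32):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size

        def prepare_data(self):
            # download once, on a single process
            MNIST(self.data_dir, download=True)

        def setup(self, stage=None):
            full = MNIST(self.data_dir, transform=transforms.ToTensor())
            self.train_set, self.val_set = random_split(full, [55000, 5000])

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.val_set, batch_size=self.batch_size)


    # the Trainer calls prepare_data/setup and the dataloader hooks for you
    trainer.fit(autoencoder, datamodule=MNISTDataModule())
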
----------------
*********
Debugging
*********

Lightning has many tools for debugging. Here are just a few of them:

Limit Batches
=============
.. testcode::

    # use only 10 train batches and 3 val batches per epoch
    trainer = Trainer(limit_train_batches=10, limit_val_batches=3)

    # use 20% of total train batches and 10% of total val batches per epoch
    trainer = Trainer(limit_train_batches=0.2, limit_val_batches=0.1)

Overfit Batches
===============
.. testcode::

    # Automatically overfit the same batches to your model for a sanity test
    # use only 10 train batches
    trainer = Trainer(overfit_batches=10)

    # use only 20% of total train batches
    trainer = Trainer(overfit_batches=0.2)

Fast Dev Run
============
.. testcode::

    # unit test all the code - hits every line of your code once to see if you have bugs,
    # instead of waiting hours to crash somewhere
    trainer = Trainer(fast_dev_run=True)

    # unit test all the code - hits every line of your code with four batches
    trainer = Trainer(fast_dev_run=4)

Val Check Interval
==================
.. testcode::

    # run validation every 25% of a training epoch
    trainer = Trainer(val_check_interval=0.25)

.. testcode::

    # Profile your code to find speed/memory bottlenecks
    Trainer(profiler="simple")

---------------

*******************
Other Cool Features
*******************

Once you define and train your first Lightning model, you might want to try other cool features like:
- :doc:`Automatic early stopping <../common/early_stopping>`
- :ref:`Automatic truncated-back-propagation-through-time <common/lightning_module:truncated_bptt_steps>`
- :ref:`Automatically scale your batch size <advanced/training_tricks:Batch Size Finder>`
- :ref:`Automatically find a good learning rate <advanced/training_tricks:Learning Rate Finder>`
- :ref:`Load checkpoints directly from S3 <common/checkpointing:Checkpoint Loading>`
- :doc:`Scale to massive compute clusters <../clouds/cluster>`
- :doc:`Use multiple dataloaders per train/val/test/predict loop <../guides/data>`
- :ref:`Use multiple optimizers to do reinforcement learning or even GANs <common/optimization:Use multiple optimizers (like GANs)>`

Read our :doc:`Guide <../starter/core_guide>` to learn more with a step-by-step walk-through!

-------------

*******
Grid AI
*******

Grid AI is our native solution for large scale training and tuning on the cloud.
`Get started for free with your GitHub or Google Account here <https://www.grid.ai/>`_.

------------

*********
Community
*********

Our community of core maintainers and thousands of expert researchers is active on our
`Slack <https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ>`_
and `GitHub Discussions <https://github.com/PyTorchLightning/pytorch-lightning/discussions>`_. Drop by
to hang out, ask Lightning questions or even discuss research!

-------------

***********
Masterclass
***********

We also offer a Masterclass to teach you the advanced uses of Lightning.
.. image:: ../_static/images/general/PTL101_youtube_thumbnail.jpg
    :width: 500
    :align: center
    :alt: Masterclass
    :target: https://www.youtube.com/playlist?list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2