.. testsetup:: *

    import os
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torch.utils.data import random_split

    import pytorch_lightning as pl
    from pytorch_lightning.core.datamodule import LightningDataModule
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer

.. _new_project:

####################
Lightning in 2 steps
####################

**In this guide we'll show you how to organize your PyTorch code into Lightning in 2 steps.**

Organizing your code with PyTorch Lightning makes your code:

* Keep all the flexibility (this is all pure PyTorch), but remove a ton of boilerplate
* More readable by decoupling the research code from the engineering
* Easier to reproduce
* Less error-prone by automating most of the training loop and tricky engineering
* Scalable to any hardware without changing your model

----------

Here's a 3 minute conversion guide for PyTorch projects:

.. raw:: html

    <video width="100%" max-width="800px" controls autoplay muted playsinline
    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_docs_animation_final.m4v"></video>

----------

*********************************
Step 0: Install PyTorch Lightning
*********************************

You can install using `pip <https://pypi.org/project/pytorch-lightning/>`_

.. code-block:: bash

    pip install pytorch-lightning

Or with `conda <https://anaconda.org/conda-forge/pytorch-lightning>`_ (see how to install conda `here <https://docs.conda.io/projects/conda/en/latest/user-guide/install/>`_):

.. code-block:: bash

    conda install pytorch-lightning -c conda-forge

You could also use conda environments:

.. code-block:: bash

    conda activate my_env
    pip install pytorch-lightning

----------

Import the following:

.. testcode::
    :skipif: not _TORCHVISION_AVAILABLE

    import os
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision import transforms
    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader, random_split
    import pytorch_lightning as pl

******************************
Step 1: Define LightningModule
******************************

.. testcode::

    class LitAutoEncoder(LightningModule):

        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential(
                nn.Linear(28 * 28, 64),
                nn.ReLU(),
                nn.Linear(64, 3)
            )
            self.decoder = nn.Sequential(
                nn.Linear(3, 64),
                nn.ReLU(),
                nn.Linear(64, 28 * 28)
            )

        def forward(self, x):
            # in lightning, forward defines the prediction/inference actions
            embedding = self.encoder(x)
            return embedding

        def training_step(self, batch, batch_idx):
            # training_step defines the train loop.
            # It is independent of forward
            x, y = batch
            x = x.view(x.size(0), -1)
            z = self.encoder(x)
            x_hat = self.decoder(z)
            loss = F.mse_loss(x_hat, x)
            # Logging to TensorBoard by default
            self.log('train_loss', loss)
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
            return optimizer

**SYSTEM VS MODEL**

A :ref:`lightning_module` defines a *system* not a model.

.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/model_system.png
    :width: 400

Examples of systems are:

- `Autoencoder <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py>`_
- `BERT <https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=yr7eaxkF-djf>`_
- `DQN <https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=IAlT0-75T_Kv>`_
- `GAN <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/gans/basic/basic_gan_module.py>`_
- `Image classifier <https://colab.research.google.com/drive/1F_RNcHzTfFuQf-LeKvSlud6x7jXYkG31#scrollTo=gEulmrbxwaYL>`_
- Seq2seq
- `SimCLR <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py>`_
- `VAE <https://github.com/PyTorchLightning/pytorch-lightning-bolts/blob/master/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py>`_

Under the hood a LightningModule is still just a :class:`torch.nn.Module` that groups all research code into a single file to make it self-contained:

- The Train loop
- The Validation loop
- The Test loop
- The Model or system of Models
- The Optimizer

You can customize any part of training (such as the backward pass) by overriding any
of the 20+ hooks found in :ref:`hooks`:

.. testcode::

    class LitAutoEncoder(LightningModule):

        def backward(self, loss, optimizer, optimizer_idx):
            loss.backward()

**FORWARD vs TRAINING_STEP**

In Lightning we separate training from inference. The training_step defines
the full training loop. We encourage users to use forward to define inference
actions.

For example, in this case we could define the autoencoder to act as an embedding extractor:

.. code-block:: python

    def forward(self, x):
        embeddings = self.encoder(x)
        return embeddings

Of course, nothing is stopping you from using forward from within the training_step.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        ...
        z = self(x)

It really comes down to your application. We do, however, recommend that you keep both intents separate.

* Use forward for inference (predicting).
* Use training_step for training.

More details in the :ref:`lightning_module` docs.

----------

**********************************
Step 2: Fit with Lightning Trainer
**********************************

First, define the data however you want. Lightning just needs a :class:`~torch.utils.data.DataLoader` for the train/val/test splits.

.. code-block:: python

    dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
    train_loader = DataLoader(dataset)
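
If you also want validation or test splits, they are just more DataLoaders. Here is a minimal sketch (the 55,000/5,000 split below is an arbitrary choice for illustration; the rest of this guide only uses the train loader):

.. code-block:: python

    # carve out 5,000 images for validation, then wrap each split in a DataLoader
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)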

Next, init the :ref:`lightning_module` and the PyTorch Lightning :class:`~pytorch_lightning.trainer.Trainer`,
then call fit with both the data and model.

.. code-block:: python

    # init model
    autoencoder = LitAutoEncoder()

    # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
    # trainer = pl.Trainer(gpus=8) (if you have GPUs)
    trainer = pl.Trainer()
    trainer.fit(autoencoder, train_loader)

The :class:`~pytorch_lightning.trainer.Trainer` automates:

* Epoch and batch iteration
* Calling of optimizer.step(), backward, zero_grad()
* Calling of .eval(), enabling/disabling grads
* :ref:`weights_loading`
* Tensorboard (see :ref:`loggers` options)
* :ref:`multi_gpu` support
* :ref:`tpu`
* :ref:`amp` support

.. tip:: If you prefer to manually manage optimizers you can use the :ref:`manual_opt` mode (ie: RL, GANs, etc...).
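
To make the list above concrete, here is a rough, heavily simplified sketch of the loop ``trainer.fit()`` runs for you. This is illustrative pseudocode only, not Lightning's actual implementation:

.. code-block:: python

    # illustrative pseudocode of what the Trainer automates (heavily simplified)
    autoencoder.train()
    torch.set_grad_enabled(True)

    for epoch in range(max_epochs):
        for batch_idx, batch in enumerate(train_loader):
            # your LightningModule hook computes the loss
            loss = autoencoder.training_step(batch, batch_idx)

            # backward pass and weight update with the optimizer
            # returned by configure_optimizers()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()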

---------

**That's it!**

These are the main 2 concepts you need to know in Lightning. All the other features of Lightning are either
features of the Trainer or LightningModule.

-----------

**************
Basic features
**************

Manual vs automatic optimization
================================

Automatic optimization
----------------------

With Lightning, you don't need to worry about when to enable/disable grads, do a backward pass, or update optimizers;
as long as you return a loss with an attached graph from the ``training_step``, Lightning will automate the optimization.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = self.encoder(batch[0])
        return loss

.. _manual_opt:

Manual optimization
-------------------

However, for certain research like GANs, reinforcement learning, or something with multiple optimizers
or an inner loop, you can turn off automatic optimization and fully control the training loop yourself.

First, turn off automatic optimization:

.. testcode::

    trainer = Trainer(automatic_optimization=False)

Now you own the train loop!

.. code-block:: python

    def training_step(self, batch, batch_idx, opt_idx):
        # access your optimizers; use_pl_optimizer=True (the default) returns
        # the Lightning-wrapped optimizers, use_pl_optimizer=False the raw ones
        (opt_a, opt_b, opt_c) = self.optimizers(use_pl_optimizer=True)

        loss_a = self.generator(batch[0])

        # use this instead of loss.backward so we can automate half precision, etc...
        self.manual_backward(loss_a, opt_a, retain_graph=True)
        self.manual_backward(loss_a, opt_a)
        opt_a.step()
        opt_a.zero_grad()

        loss_b = self.discriminator(batch[0])
        self.manual_backward(loss_b, opt_b)
        ...

Predict or Deploy
=================

When you're done training, you have 3 options to use your LightningModule for predictions.

Option 1: Sub-models
--------------------

Pull out any model inside your system for predictions.

.. code-block:: python

    # ----------------------------------
    # to use as embedding extractor
    # ----------------------------------
    autoencoder = LitAutoEncoder.load_from_checkpoint('path/to/checkpoint_file.ckpt')
    encoder_model = autoencoder.encoder
    encoder_model.eval()

    # ----------------------------------
    # to use as image generator
    # ----------------------------------
    decoder_model = autoencoder.decoder
    decoder_model.eval()

Option 2: Forward
-----------------

You can also add a forward method to do predictions however you want.

.. testcode::

    # ----------------------------------
    # using the AE to extract embeddings
    # ----------------------------------
    class LitAutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.encoder = nn.Sequential()

        def forward(self, x):
            embedding = self.encoder(x)
            return embedding

    autoencoder = LitAutoEncoder()
    embedding = autoencoder(torch.rand(1, 28 * 28))

.. code-block:: python

    # ----------------------------------
    # or using the AE to generate images
    # ----------------------------------
    class LitAutoEncoder(LightningModule):
        def __init__(self):
            super().__init__()
            self.decoder = nn.Sequential()

        def forward(self):
            z = torch.rand(1, 3)
            image = self.decoder(z)
            image = image.view(1, 1, 28, 28)
            return image

    autoencoder = LitAutoEncoder()
    image_sample = autoencoder()

Option 3: Production
--------------------

For production systems, onnx or torchscript are much faster. Make sure you have added
a forward method or trace only the sub-models you need.

.. code-block:: python

    # ----------------------------------
    # torchscript
    # ----------------------------------
    autoencoder = LitAutoEncoder()
    torch.jit.save(autoencoder.to_torchscript(), "model.pt")
    os.path.isfile("model.pt")

.. code-block:: python

    # ----------------------------------
    # onnx
    # ----------------------------------
    import tempfile

    with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
        autoencoder = LitAutoEncoder()
        input_sample = torch.randn((1, 28 * 28))
        autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
        os.path.isfile(tmpfile.name)
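
To actually serve the exported ONNX file you could, for example, run it with ``onnxruntime``. This is just a sketch and assumes ``onnxruntime`` is installed; ``tmpfile.name`` is the file exported above:

.. code-block:: python

    import numpy as np
    import onnxruntime

    # load the exported model and run a single random sample through it
    session = onnxruntime.InferenceSession(tmpfile.name)
    input_name = session.get_inputs()[0].name
    sample = np.random.randn(1, 28 * 28).astype(np.float32)
    outputs = session.run(None, {input_name: sample})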

--------------------

Using CPUs/GPUs/TPUs
====================

It's trivial to use CPUs, GPUs or TPUs in Lightning. There's **NO NEED** to change your code, simply change the :class:`~pytorch_lightning.trainer.Trainer` options.

.. testcode::

    # train on CPU
    trainer = Trainer()

.. testcode::

    # train on 8 CPUs
    trainer = Trainer(num_processes=8)

.. code-block:: python

    # train on 1024 CPUs across 128 machines
    trainer = pl.Trainer(
        num_processes=8,
        num_nodes=128
    )

.. code-block:: python

    # train on 1 GPU
    trainer = pl.Trainer(gpus=1)

.. code-block:: python

    # train on multiple GPUs across nodes (32 gpus here)
    trainer = pl.Trainer(
        gpus=4,
        num_nodes=8
    )

.. code-block:: python

    # train on gpu 1, 3, 5 (3 gpus total)
    trainer = pl.Trainer(gpus=[1, 3, 5])

.. code-block:: python

    # Multi GPU with mixed precision
    trainer = pl.Trainer(gpus=2, precision=16)

.. code-block:: python

    # Train on TPUs
    trainer = pl.Trainer(tpu_cores=8)

Without changing a SINGLE line of your code, you can now do the following with the above code:

.. code-block:: python

    # train on TPUs using 16 bit precision
    # using only half the training data and checking validation every quarter of a training epoch
    trainer = pl.Trainer(
        tpu_cores=8,
        precision=16,
        limit_train_batches=0.5,
        val_check_interval=0.25
    )

-----------

Checkpoints
===========

Lightning automatically saves your model. Once you've trained, you can load the checkpoints as follows:

.. code-block:: python

    model = LitModel.load_from_checkpoint(path)

The above checkpoint contains all the arguments needed to init the model and set the state dict.
If you prefer to do it manually, here's the equivalent:

.. code-block:: python

    # load the ckpt
    ckpt = torch.load('path/to/checkpoint.ckpt')

    # equivalent to the above
    model = LitModel()
    model.load_state_dict(ckpt['state_dict'])
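
For ``load_from_checkpoint`` to be able to rebuild the model, the init arguments have to be stored in the checkpoint. Calling ``self.save_hyperparameters()`` in ``__init__`` takes care of that. A minimal sketch with made-up init arguments:

.. code-block:: python

    class LitModel(LightningModule):
        def __init__(self, hidden_dim=64, learning_rate=1e-3):
            super().__init__()
            # stores hidden_dim and learning_rate in the checkpoint so that
            # LitModel.load_from_checkpoint(path) can re-init the model later
            self.save_hyperparameters()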

---------

Data flow
=========

Each loop (training, validation, test) has three hooks you can implement:

- x_step
- x_step_end
- x_epoch_end

To illustrate how data flows, we'll use the training loop (ie: x=training)

.. code-block:: python

    outs = []
    for batch in data:
        out = training_step(batch)
        outs.append(out)
    training_epoch_end(outs)

The equivalent in Lightning is:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        prediction = ...
        return prediction

    def training_epoch_end(self, training_step_outputs):
        for prediction in training_step_outputs:
            # do something with these
            ...

In the event that you use DP or DDP2 distributed modes (ie: split a batch across GPUs),
use the x_step_end to manually aggregate (or don't implement it to let Lightning auto-aggregate for you).

.. code-block:: python

    for batch in data:
        model_copies = copy_model_per_gpu(model, num_gpus)
        batch_split = split_batch_per_gpu(batch, num_gpus)

        gpu_outs = []
        for model, batch_part in zip(model_copies, batch_split):
            # LightningModule hook
            gpu_out = model.training_step(batch_part)
            gpu_outs.append(gpu_out)

        # LightningModule hook
        out = training_step_end(gpu_outs)

The Lightning equivalent is:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        loss = ...
        return loss

    def training_step_end(self, losses):
        gpu_0_loss = losses[0]
        gpu_1_loss = losses[1]
        return (gpu_0_loss + gpu_1_loss) / 2

.. tip:: The validation and test loops have the same structure.
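
For example, the validation loop of the autoencoder above could mirror the training hooks one-for-one. A minimal sketch:

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        val_loss = F.mse_loss(x_hat, x)
        return val_loss

    def validation_epoch_end(self, validation_step_outputs):
        # average the per-batch losses collected from validation_step
        avg_loss = torch.stack(validation_step_outputs).mean()
        self.log('avg_val_loss', avg_loss)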

-----------------

Logging
=======

To log to Tensorboard, your favorite logger, and/or the progress bar, use the
:func:`~pytorch_lightning.core.lightning.LightningModule.log` method which can be called from
any method in the LightningModule.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        self.log('my_metric', x)

The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a few options:

- on_step (logs the metric at that step in training)
- on_epoch (automatically accumulates and logs at the end of the epoch)
- prog_bar (logs to the progress bar)
- logger (logs to the logger like Tensorboard)

Depending on where the log is called from, Lightning auto-determines the correct mode for you. But of course
you can override the default behavior by manually setting the flags.

.. note:: Setting on_epoch=True will accumulate your logged values over the full training epoch.

.. code-block:: python

    def training_step(self, batch, batch_idx):
        self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

.. note::
    The loss value shown in the progress bar is smoothed (averaged) over the last values,
    so it differs from the actual loss returned in the train/validation step.

You can also use any method of your logger directly:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        tensorboard = self.logger.experiment
        tensorboard.any_summary_writer_method_you_want()

Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs:

.. code-block:: bash

    tensorboard --logdir ./lightning_logs

.. note::
    Lightning automatically shows the loss value returned from ``training_step`` in the progress bar.
    So, no need to explicitly log like this ``self.log('loss', loss, prog_bar=True)``.
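
If you have several metrics to record at once, ``self.log_dict`` logs them in a single call (a short sketch; ``loss`` and ``acc`` are placeholders for values you computed):

.. code-block:: python

    def training_step(self, batch, batch_idx):
        ...
        # logs both values with the same default behavior as self.log
        self.log_dict({'train_loss': loss, 'train_acc': acc})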

Read more about :ref:`loggers`.

----------------

Optional extensions
===================

Callbacks
---------

A callback is an arbitrary self-contained program that can be executed at arbitrary parts of the training loop.

Here's an example adding a not-so-fancy learning rate decay rule:

.. testcode::

    from pytorch_lightning.callbacks import Callback

    class DecayLearningRate(Callback):

        def __init__(self):
            self.old_lrs = []

        def on_train_start(self, trainer, pl_module):
            # track the initial learning rates
            for opt_idx, optimizer in enumerate(trainer.optimizers):
                group = [param_group['lr'] for param_group in optimizer.param_groups]
                self.old_lrs.append(group)

        def on_train_epoch_end(self, trainer, pl_module, outputs):
            for opt_idx, optimizer in enumerate(trainer.optimizers):
                old_lr_group = self.old_lrs[opt_idx]
                new_lr_group = []
                for p_idx, param_group in enumerate(optimizer.param_groups):
                    old_lr = old_lr_group[p_idx]
                    new_lr = old_lr * 0.98
                    new_lr_group.append(new_lr)
                    param_group['lr'] = new_lr
                self.old_lrs[opt_idx] = new_lr_group

    # And pass the callback to the Trainer
    decay_callback = DecayLearningRate()
    trainer = Trainer(callbacks=[decay_callback])

Things you can do with a callback:

- Send emails at some point in training
- Grow the model
- Update learning rates
- Visualize gradients
- ...
- You are only limited by your imagination

:ref:`Learn more about custom callbacks <callbacks>`.

LightningDataModules
--------------------

DataLoaders and data processing code tends to end up scattered around.
Make your data code reusable by organizing it into a :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

.. testcode::

    class MNISTDataModule(LightningDataModule):

        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size

        # When doing distributed training, Datamodules have two optional arguments for
        # granular control over download/prepare/splitting data:

        # OPTIONAL, called only on 1 GPU/machine
        def prepare_data(self):
            MNIST(os.getcwd(), train=True, download=True)
            MNIST(os.getcwd(), train=False, download=True)

        # OPTIONAL, called for every GPU/machine (assigning state is OK)
        def setup(self, stage):
            # transforms
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
            # split dataset
            if stage == 'fit':
                mnist_train = MNIST(os.getcwd(), train=True, transform=transform)
                self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
            if stage == 'test':
                self.mnist_test = MNIST(os.getcwd(), train=False, transform=transform)

        # return the dataloader for each split
        def train_dataloader(self):
            mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
            return mnist_train

        def val_dataloader(self):
            mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
            return mnist_val

        def test_dataloader(self):
            mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
            return mnist_test

:class:`~pytorch_lightning.core.datamodule.LightningDataModule` is designed to enable sharing and reusing data splits
and transforms across different projects. It encapsulates all the steps needed to process data: downloading,
tokenizing, processing etc.

Now you can simply pass your :class:`~pytorch_lightning.core.datamodule.LightningDataModule` to
the :class:`~pytorch_lightning.trainer.Trainer`:

.. code-block:: python

    # init model
    model = LitModel()

    # init data
    dm = MNISTDataModule()

    # train
    trainer = pl.Trainer()
    trainer.fit(model, dm)

    # test
    trainer.test(datamodule=dm)
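
A datamodule is also handy on its own, e.g. to quickly inspect a batch without a Trainer. A small sketch:

.. code-block:: python

    dm = MNISTDataModule()
    dm.prepare_data()
    dm.setup('fit')

    # grab one batch from the train split
    x, y = next(iter(dm.train_dataloader()))
    print(x.shape)  # torch.Size([32, 1, 28, 28])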

DataModules are specifically useful for building models based on data. Read more on :ref:`datamodules`.

------

Debugging
=========

Lightning has many tools for debugging. Here is an example of just a few of them:

.. testcode::

    # use only 10 train batches and 3 val batches
    trainer = Trainer(limit_train_batches=10, limit_val_batches=3)

.. testcode::

    # Automatically overfit the same batch of your model for a sanity test
    trainer = Trainer(overfit_batches=1)

.. testcode::

    # unit test all the code - hits every line of your code once to see if you have bugs,
    # instead of waiting hours to crash on validation
    trainer = Trainer(fast_dev_run=True)

.. testcode::

    # train only 20% of an epoch
    trainer = Trainer(limit_train_batches=0.2)

.. testcode::

    # run validation every 25% of a training epoch
    trainer = Trainer(val_check_interval=0.25)

.. testcode::

    # Profile your code to find speed/memory bottlenecks
    trainer = Trainer(profiler=True)

---------------

*******************
Other cool features
*******************

Once you define and train your first Lightning model, you might want to try other cool features like

- :ref:`Automatic early stopping <early_stopping>`
- :ref:`Automatic truncated-back-propagation-through-time <trainer:truncated_bptt_steps>`
- :ref:`Automatically scale your batch size <training_tricks:Auto scaling of batch size>`
- :ref:`Automatically find a good learning rate <lr_finder>`
- :ref:`Load checkpoints directly from S3 <weights_loading:Checkpoint Loading>`
- :ref:`Scale to massive compute clusters <slurm>`
- :ref:`Use multiple dataloaders per train/val/test loop <multiple_loaders>`
- :ref:`Use multiple optimizers to do reinforcement learning or even GANs <optimizers:Use multiple optimizers (like GANs)>`

Or read our :ref:`introduction_guide` to learn more!

-------------

Grid AI
=======

Grid AI is our native solution for large scale training and tuning on the cloud provider of your choice.

`Click here to request early-access <https://www.grid.ai/>`_.

------------

*********
Community
*********

Our community of core maintainers and thousands of expert researchers is active on our
`Slack <https://join.slack.com/t/pytorch-lightning/shared_invite/zt-f6bl2l0l-JYMK3tbAgAmGRrlNr00f1A>`_
and `Forum <https://forums.pytorchlightning.ai/>`_. Drop by to hang out, ask Lightning questions or even discuss research!

-------------

Masterclass
===========

We also offer a Masterclass to teach you the advanced uses of Lightning.

.. image:: _images/general/PTL101_youtube_thumbnail.jpg
    :width: 500
    :align: center
    :alt: Masterclass
    :target: https://www.youtube.com/playlist?list=PLaMu-SDt_RB5NUm67hU2pdE75j6KaIOv2