.. testsetup:: *

    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.trainer.trainer import Trainer
|
|
|
|
.. _introduction-guide:
|
|
|
|
Step-by-step walk-through
|
|
=========================
|
|
PyTorch Lightning provides a very simple template for organizing your PyTorch code. Once
|
|
you've organized it into a LightningModule, it automates most of the training for you.
|
|
|
|
In this guide, we will walk through the API by looking at how you would organize your PyTorch
|
|
code to work with Lightning.
|
|
|
|
.. raw:: html
|
|
|
|
<video width="100%" controls autoplay src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/pl_quick_start_full.m4v"></video>
|
|
|
|
|
|
|
|
|
By doing this refactor you'll:

- Make your code more reusable.
- Keep all of your flexibility.
- Gain free features such as 16-bit precision and distributed training. While this may be overkill
  for small projects, you won't get bogged down with engineering as your project grows in complexity.
|
|
|
|
----------------
|
|
|
|
Goal of this guide
|
|
------------------
|
|
This guide walks through the major parts of the library to help you understand
what each part does. At the end of the day, you write the same PyTorch code; you just organize it
into the LightningModule template, which means you keep ALL the flexibility without having to deal with
any of the boilerplate code.
|
|
|
|
To show how Lightning works, we'll start with an MNIST classifier. We'll end showing how
|
|
to use inheritance to very quickly create an AutoEncoder.
|
|
|
|
.. note:: Any DL/ML PyTorch project fits into the Lightning structure. Here we just focus on 3 types
    of research to illustrate.
|
|
|
|
----------------
|
|
|
|
Installing Lightning
|
|
--------------------
|
|
Lightning is trivial to install. From within a conda environment:

.. code-block:: bash

    conda activate my_env
    pip install pytorch-lightning

Or without conda environments, anywhere you can use pip:

.. code-block:: bash

    pip install pytorch-lightning

Or with conda:

.. code-block:: bash

    conda install pytorch-lightning -c conda-forge
|
|
|
|
----------------
|
|
|
|
Lightning Philosophy
|
|
--------------------
|
|
Lightning factors DL/ML code into three types:
|
|
|
|
- Research code
|
|
- Engineering code
|
|
- Non-essential code
|
|
|
|
Research code
|
|
^^^^^^^^^^^^^
|
|
In the MNIST generation example, the research code would be the particular system and how it's trained (i.e., a GAN or VAE).
In Lightning, this code is abstracted out by the `LightningModule`.

.. code-block:: python

    # schematic research code (not runnable as-is)
    l1 = nn.Linear(...)
    l2 = nn.Linear(...)
    decoder = Decoder()

    x1 = l1(x)
    x2 = l2(x1)
    out = decoder(features, x)

    loss = perceptual_loss(x1, x2, x) + CE(out, x)
|
|
|
|
Engineering code
|
|
^^^^^^^^^^^^^^^^
|
|
|
|
The Engineering code is all the code related to training this system: things such as early stopping, distribution
over GPUs, 16-bit precision, etc. This is normally code that is THE SAME across most projects.

In Lightning, this code is abstracted out by the `Trainer`.

.. code-block:: python

    model.cuda(0)
    x = x.cuda(0)

    distributed = DistributedParallel(model)

    with gpu_zero:
        download_data()

    dist.barrier()
|
|
|
|
Non-essential code
|
|
^^^^^^^^^^^^^^^^^^
|
|
This is code that helps the research but isn't relevant to the research code. Some examples might be:

1. Inspect gradients.
2. Log to tensorboard.

In Lightning this code is abstracted out by `Callbacks`.

.. code-block:: python

    # log samples
    z = Q.rsample()
    generated = decoder(z)
    self.experiment.log('images', generated)
|
|
|
|
----------------
|
|
|
|
Elements of a research project
|
|
------------------------------
|
|
Every research project requires the same core ingredients:
|
|
|
|
1. A model
|
|
2. Train/val/test data
|
|
3. Optimizer(s)
|
|
4. Training step computations
|
|
5. Validation step computations
|
|
6. Test step computations
|
|
|
|
|
|
The Model
|
|
^^^^^^^^^
|
|
The LightningModule provides the structure on how to organize these 6 ingredients.
|
|
|
|
Let's first start with the model. In this case we'll design
|
|
a 3-layer neural network.
|
|
|
|
.. testcode::

    import torch
    from torch.nn import functional as F
    from torch import nn
    from pytorch_lightning.core.lightning import LightningModule

    class LitMNIST(LightningModule):

        def __init__(self):
            super().__init__()

            # mnist images are (1, 28, 28) (channels, width, height)
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 256)
            self.layer_3 = torch.nn.Linear(256, 10)

        def forward(self, x):
            batch_size, channels, width, height = x.size()

            # (b, 1, 28, 28) -> (b, 1*28*28)
            x = x.view(batch_size, -1)

            # layer 1
            x = self.layer_1(x)
            x = torch.relu(x)

            # layer 2
            x = self.layer_2(x)
            x = torch.relu(x)

            # layer 3
            x = self.layer_3(x)

            # probability distribution over labels
            x = torch.log_softmax(x, dim=1)

            return x
|
|
|
|
Notice this is a `LightningModule` instead of a `torch.nn.Module`. A LightningModule is
|
|
equivalent to a PyTorch Module except it has added functionality. However, you can use it
|
|
EXACTLY the same as you would a PyTorch Module.
|
|
|
|
.. testcode::

    net = LitMNIST()
    x = torch.Tensor(1, 1, 28, 28)
    out = net(x)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: python

    torch.Size([1, 10])
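As a quick sanity check of the forward pass above, the following sketch rebuilds the same three layers as a plain `torch.nn.Module`. The class name `PlainMNIST` is made up for this illustration, and a plain module is used only so that the snippet depends on PyTorch alone. It checks the output shape and that exponentiating the log-softmax output gives probabilities that sum to 1.

```python
import torch
from torch import nn


class PlainMNIST(nn.Module):
    """Hypothetical stand-in for LitMNIST that only needs PyTorch."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)

    def forward(self, x):
        # (b, 1, 28, 28) -> (b, 784)
        x = x.view(x.size(0), -1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        # log-probabilities over the 10 digit classes
        return torch.log_softmax(self.layer_3(x), dim=1)


net = PlainMNIST()
batch = torch.rand(4, 1, 28, 28)  # random values standing in for 4 images
out = net(batch)
out_shape = tuple(out.shape)
prob_sums = out.exp().sum(dim=1)  # one sum per image
```

Each row of `out` holds log-probabilities, which is why the later training code pairs it with `F.nll_loss`.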
|
|
|
|
Data
|
|
^^^^
|
|
|
|
Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST.
|
|
|
|
.. testcode::
    :skipif: not TORCHVISION_AVAILABLE

    from torch.utils.data import DataLoader, random_split
    from torchvision.datasets import MNIST
    import os
    from torchvision import datasets, transforms

    # transforms
    # prepare transforms standard to MNIST
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])

    # data
    # pass the transform so the DataLoader yields tensors instead of PIL images
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_train = DataLoader(mnist_train, batch_size=64)
|
|
|
|
.. testoutput::
    :hide:
    :skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not TORCHVISION_AVAILABLE

    Downloading ...
    Extracting ...
    Downloading ...
    Extracting ...
    Downloading ...
    Extracting ...
    Processing...
    Done!
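The `Normalize((0.1307,), (0.3081,))` transform used when loading MNIST above applies `(pixel - mean) / std` to every pixel after `ToTensor` has scaled it to the range [0, 1]. The plain-Python sketch below works through that arithmetic for the two extreme pixel values; the helper name `normalize` is made up for this illustration.

```python
MEAN, STD = 0.1307, 0.3081  # the constants passed to Normalize above


def normalize(pixel):
    """Apply (pixel - mean) / std to one pixel value already scaled to [0, 1]."""
    return (pixel - MEAN) / STD


black = normalize(0.0)  # background pixel, roughly -0.42
white = normalize(1.0)  # fully lit pixel, roughly 2.82
```

After normalization the inputs are no longer confined to [0, 1]; the mostly dark background ends up slightly below zero.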
|
|
|
|
There's nothing special you need to do with PyTorch Lightning! Just pass in the dataloaders to the `.fit()` function.
|
|
|
|
.. code-block:: python

    model = LitMNIST()
    trainer = Trainer()
    trainer.fit(model, mnist_train)
|
|
|
|
DataModules
|
|
***********
|
|
Defining free-floating dataloaders, splits, download instructions and such can get messy.
|
|
In this case, it's better to group the full definition of a dataset into a `DataModule` which includes:
|
|
|
|
- Download instructions
|
|
- Processing instructions
|
|
- Split instructions
|
|
- Train dataloader
|
|
- Val dataloader(s)
|
|
- Test dataloader(s)
|
|
|
|
.. code-block:: python

    class MyDataModule(pl.LightningDataModule):

        def __init__(self):
            super().__init__()
            self.train_dims = None
            self.vocab_size = 0

        def prepare_data(self):
            # called only on 1 GPU
            download_dataset()
            tokenize()
            build_vocab()

        def setup(self, stage=None):
            # called on every GPU
            vocab = load_vocab()
            self.vocab_size = len(vocab)

            self.train, self.val, self.test = load_datasets()
            self.train_dims = self.train.next_batch.size()

        def train_dataloader(self):
            # transforms belong on the dataset; the DataLoader only batches it
            return DataLoader(self.train)

        def val_dataloader(self):
            return DataLoader(self.val)

        def test_dataloader(self):
            return DataLoader(self.test)
|
|
|
|
Using DataModules allows easier sharing of full dataset definitions.
|
|
|
|
.. code-block:: python

    # use an MNIST dataset
    mnist_dm = MNISTDataModule()
    model = LitModel(num_classes=mnist_dm.num_classes)
    trainer.fit(model, mnist_dm)

    # or other datasets with the same model
    imagenet_dm = ImagenetDataModule()
    model = LitModel(num_classes=imagenet_dm.num_classes)
    trainer.fit(model, imagenet_dm)
|
|
|
|
.. note:: `prepare_data` is called on only 1 GPU in distributed training (automatically)
|
|
.. note:: `setup` is called on every GPU (automatically)
|
|
|
|
Models defined by data
|
|
**********************
|
|
When your models need to know about the data, it's best to process the data before passing it to the model.
|
|
|
|
.. code-block:: python

    # init dm AND call the processing manually
    dm = ImagenetDataModule()
    dm.prepare_data()
    dm.setup()

    model = LitModel(out_features=dm.num_classes, img_width=dm.img_width, img_height=dm.img_height)
    trainer.fit(model, dm)
|
|
|
|
|
|
1. Use `prepare_data` to download and process the dataset.
2. Use `setup` to do splits and build your model internals.
|
|
|
|
|
|
|
|
|
.. testcode::

    class LitMNIST(LightningModule):

        def __init__(self):
            super().__init__()
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

        def setup(self, step):
            # step is either 'fit' or 'test'; 90% of the time it is not relevant
            data = load_data()
            num_classes = data.classes
            self.l1 = nn.Linear(..., num_classes)
|
|
|
|
Optimizer
|
|
^^^^^^^^^
|
|
|
|
Next we choose what optimizer to use for training our system.
|
|
In PyTorch we do it as follows:
|
|
|
|
.. code-block:: python

    from torch.optim import Adam
    optimizer = Adam(LitMNIST().parameters(), lr=1e-3)

In Lightning we do the same but organize it under the `configure_optimizers` method.

.. testcode::

    class LitMNIST(LightningModule):

        def configure_optimizers(self):
            return Adam(self.parameters(), lr=1e-3)
|
|
|
|
.. note:: The LightningModule itself has the parameters, so pass in self.parameters()
|
|
|
|
However, if you have multiple optimizers, pass each one the parameters of the matching sub-model:

.. testcode::

    class LitMNIST(LightningModule):

        def configure_optimizers(self):
            # each optimizer gets only the parameters of its own sub-network
            return Adam(self.generator.parameters(), lr=1e-3), Adam(self.discriminator.parameters(), lr=1e-3)
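To illustrate what "matching parameters" means, here is a minimal sketch using a plain `torch.nn.Module`. The class `TinyGAN` and its layer sizes are invented for this example, and `configure_optimizers` is called by hand rather than by a Trainer. It checks that the two optimizers manage disjoint sets of parameters that together cover the whole model.

```python
import torch
from torch import nn
from torch.optim import Adam


class TinyGAN(nn.Module):
    """Hypothetical two-part model; the layer sizes are arbitrary."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(8, 16)
        self.discriminator = nn.Linear(16, 1)

    def configure_optimizers(self):
        # each optimizer only sees the parameters of its own sub-network
        opt_g = Adam(self.generator.parameters(), lr=1e-3)
        opt_d = Adam(self.discriminator.parameters(), lr=1e-3)
        return opt_g, opt_d


model = TinyGAN()
opt_g, opt_d = model.configure_optimizers()
g_ids = {id(p) for group in opt_g.param_groups for p in group["params"]}
d_ids = {id(p) for group in opt_d.param_groups for p in group["params"]}
all_ids = {id(p) for p in model.parameters()}
```

Because the sets do not overlap, stepping one optimizer never updates the other sub-network's weights.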
|
|
|
|
Training step
|
|
^^^^^^^^^^^^^
|
|
|
|
The training step is what happens inside the training loop.
|
|
|
|
.. code-block:: python

    for epoch in epochs:
        for batch in data:
            # TRAINING STEP
            # ....
            # TRAINING STEP
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
|
|
|
|
In the case of MNIST, we do the following:

.. code-block:: python

    for epoch in epochs:
        for batch in data:
            # ------ TRAINING STEP START ------
            x, y = batch
            logits = model(x)
            loss = F.nll_loss(logits, y)
            # ------ TRAINING STEP END ------

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
|
|
|
|
In Lightning, everything that is in the training step gets organized under the `training_step` function
in the LightningModule.

.. testcode::

    class LitMNIST(LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss
|
|
|
|
Again, this is the same PyTorch code except that it has been organized by the LightningModule.
The code inside `training_step` is unrestricted, which means it can be as complicated as a full seq-2-seq model, RL loop, GAN, etc.
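The split between `training_step` and the surrounding loop can be tried out without Lightning. In the sketch below a hand-written loop plays the role of the Trainer and calls `training_step` on a plain `torch.nn.Module`. The class `TinyClassifier`, the synthetic data and the hyperparameters are all assumptions made for illustration, not part of the MNIST example.

```python
import torch
from torch import nn
from torch.nn import functional as F


class TinyClassifier(nn.Module):
    """Hypothetical model: 4 input features, 3 classes."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 3)

    def forward(self, x):
        return torch.log_softmax(self.layer(x), dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.nll_loss(self(x), y)


torch.manual_seed(0)
model = TinyClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
# 8 synthetic batches of 16 examples each
data = [(torch.randn(16, 4), torch.randint(0, 3, (16,))) for _ in range(8)]

losses = []
for epoch in range(2):
    for batch_idx, batch in enumerate(data):
        loss = model.training_step(batch, batch_idx)  # the research code
        loss.backward()        # everything below is loop boilerplate
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
```

Only the body of `training_step` is specific to the model; the three lines after it are the part that a Trainer takes over.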
|
|
|
|
TrainResult
|
|
^^^^^^^^^^^
|
|
Whenever you'd like to log or sync values across GPUs, use `TrainResult`. It lets you:

- Log to TensorBoard or the other logger of your choice.
- Log to the progress bar.
- Log on every step.
- Log aggregate epoch metrics.
- Average values across GPUs/TPU cores.
|
|
|
|
.. code-block:: python

    def training_step(...):
        return loss

        # equivalent
        return pl.TrainResult(loss)

        # log a metric
        result = pl.TrainResult(loss)
        result.log('train_loss', loss)

        # equivalent
        result.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True, reduce_fx=torch.mean)
|
|
|
|
When training across accelerators (GPUs/TPUs) you can sync a metric if needed.
|
|
|
|
.. code-block:: python

    # sync across GPUs / TPUs, etc...
    result.log('train_loss', loss, sync_dist=True)
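To make the `on_step`, `on_epoch` and `sync_dist` options above concrete, the plain-Python sketch below mimics only the arithmetic. The helper names `sync_across_devices` and `reduce_epoch` and the loss values are invented for illustration and are not Lightning APIs; the reduction is assumed to be a mean, matching `reduce_fx=torch.mean`.

```python
from statistics import mean

# hypothetical per-step losses recorded on two devices (made-up numbers)
device_losses = {
    "gpu0": [0.9, 0.7, 0.5],
    "gpu1": [1.1, 0.9, 0.7],
}


def sync_across_devices(step_idx):
    """Average one step's value over all devices (the idea behind sync_dist)."""
    return mean(values[step_idx] for values in device_losses.values())


def reduce_epoch(step_values):
    """Collapse per-step values into one epoch-level number (mean reduction)."""
    return mean(step_values)


synced_steps = [sync_across_devices(i) for i in range(3)]  # per-step values
epoch_value = reduce_epoch(synced_steps)                   # per-epoch value
```

Logging with `on_step=True` corresponds to reporting each entry of `synced_steps`, while `on_epoch=True` corresponds to reporting the single `epoch_value`.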
|
|
|
|
If you are only using a training loop (`training_step`) without a
validation or test loop (`validation_step`, `test_step`), you can still use EarlyStopping or automatic checkpointing:

.. code-block:: python

    result = pl.TrainResult(loss, checkpoint_on=loss, early_stop_on=loss)
    return result
|
|
|
|
----------------
|
|
|
|
Training
|
|
--------
|
|
So far we defined 4 key ingredients in pure PyTorch but organized the code with the LightningModule.
|
|
|
|
1. Model.
|
|
2. Training data.
|
|
3. Optimizer.
|
|
4. What happens in the training loop.
|
|
|
|
|
|
|
|
|
For clarity, we'll recall that the full LightningModule now looks like this.
|
|
|
|
.. code-block:: python

    class LitMNIST(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 256)
            self.layer_3 = torch.nn.Linear(256, 10)

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x = self.layer_1(x)
            x = torch.relu(x)
            x = self.layer_2(x)
            x = torch.relu(x)
            x = self.layer_3(x)
            x = torch.log_softmax(x, dim=1)
            return x

        def configure_optimizers(self):
            # the optimizer defined in the Optimizer section above
            return torch.optim.Adam(self.parameters(), lr=1e-3)

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)

            # using TrainResult to enable logging
            result = pl.TrainResult(loss)
            result.log('train_loss', loss)

            return result
|
|
|
|
Again, this is the same PyTorch code, except that it's organized
by the LightningModule. This organization now lets us train this model.
|
|
|
|
Train on CPU
|
|
^^^^^^^^^^^^
|
|
|
|
.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitMNIST()
    trainer = Trainer()
    trainer.fit(model, train_loader)

You should see a weights summary and the following progress bar:

.. code-block:: shell

    Epoch 1: 4%|▎ | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
|
|
|
|
Logging
|
|
^^^^^^^
|
|
|
|
When we returned the `TrainResult` from `training_step`, the logged value went to the built-in TensorBoard logger.
But you could have also logged by calling the logger directly:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        # ...
        loss = ...
        # self.logger.experiment is the underlying TensorBoard SummaryWriter
        self.logger.experiment.add_scalar('loss', loss, self.global_step)

        # equivalent
        result = pl.TrainResult(loss)
        result.log('loss', loss)

Either approach generates TensorBoard logs automatically.
|
|
|
|
.. figure:: /_images/mnist_imgs/mnist_tb.png
|
|
:alt: mnist CPU bar
|
|
:width: 500
|
|
|
|
|
|
|
|
|
But you can also use any of the `number of other loggers <loggers.rst>`_ we support.
|
|
|
|
GPU training
|
|
^^^^^^^^^^^^
|
|
|
|
The real power comes from the Trainer flags. For instance, to run this model on a GPU:

.. code-block:: python

    model = LitMNIST()
    trainer = Trainer(gpus=1)
    trainer.fit(model, train_loader)
|
|
|
|
|
|
.. figure:: /_images/mnist_imgs/mnist_gpu.png
|
|
:alt: mnist GPU bar
|
|
|
|
Multi-GPU training
|
|
^^^^^^^^^^^^^^^^^^
|
|
|
|
Or you can also train on multiple GPUs.

.. code-block:: python

    model = LitMNIST()
    trainer = Trainer(gpus=8)
    trainer.fit(model, train_loader)

Or multiple nodes:

.. code-block:: python

    # (32 GPUs)
    model = LitMNIST()
    trainer = Trainer(gpus=8, num_nodes=4, distributed_backend='ddp')
    trainer.fit(model, train_loader)
|
|
|
|
Refer to the `distributed computing guide for more details <multi_gpu.rst>`_.
|
|
|
|
TPUs
|
|
^^^^
|
|
Did you know you can use PyTorch on TPUs? It's very hard to do, but we've
|
|
worked with the xla team to use their awesome library to get this to work
|
|
out of the box!
|
|
|
|
Let's train on Colab (`full demo available here <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_)
|
|
|
|
First, change the runtime to TPU (and reinstall lightning).
|
|
|
|
.. figure:: /_images/mnist_imgs/runtime_tpu.png
    :alt: change runtime to TPU
    :width: 400

.. figure:: /_images/mnist_imgs/restart_runtime.png
    :alt: restart runtime
    :width: 400
|
|
|
|
|
|
|
|
|
Next, install the required xla library (adds support for PyTorch on TPUs):

.. code-block:: shell

    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py

    !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
|
|
|
|
In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy
|
|
of this program. This means that without taking any care you will download the dataset N times which
|
|
will cause all sorts of issues.
|
|
|
|
To solve this problem, make sure your download code is in the `prepare_data` method in the DataModule.
|
|
In this method we do all the preparation we need to do once (instead of on every gpu).
|
|
|
|
`prepare_data` can be called in two ways, once per node or only on the root node
|
|
(`Trainer(prepare_data_per_node=False)`).
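The difference between "once per node" and "only on the root node" can be sketched with rank arithmetic. The function below is a hypothetical illustration of which processes would run `prepare_data` under each setting. It assumes each process knows its node index and its index within the node, and it is not Lightning's actual implementation.

```python
def should_prepare_data(node_rank, local_rank, prepare_data_per_node=True):
    """Return True if this process is the one that runs prepare_data."""
    if prepare_data_per_node:
        # one process per node: the first local process on every node
        return local_rank == 0
    # one process overall: the first local process on the root node
    return node_rank == 0 and local_rank == 0


# 2 nodes x 4 processes per node (made-up cluster shape)
processes = [(node, local) for node in range(2) for local in range(4)]
per_node = [p for p in processes if should_prepare_data(*p, prepare_data_per_node=True)]
root_only = [p for p in processes if should_prepare_data(*p, prepare_data_per_node=False)]
```

Downloading once per node suits clusters where every node has its own disk, while the root-only setting suits a shared filesystem.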
|
|
|
|
.. code-block:: python

    class MNISTDataModule(LightningDataModule):
        def __init__(self, batch_size=64):
            super().__init__()
            self.batch_size = batch_size

        def prepare_data(self):
            # download only
            MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
            MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

        def setup(self, stage):
            # transform
            transform = transforms.Compose([transforms.ToTensor()])
            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
            mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

            # train/val split
            mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

            # assign to use in dataloaders
            self.train_dataset = mnist_train
            self.val_dataset = mnist_val
            self.test_dataset = mnist_test

        def train_dataloader(self):
            return DataLoader(self.train_dataset, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.val_dataset, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.test_dataset, batch_size=self.batch_size)
|
|
|
|
The `prepare_data` method is also a good place to do any data processing that needs to be done only
|
|
once (ie: download or tokenize, etc...).
|
|
|
|
.. note:: Lightning inserts the correct DistributedSampler for distributed training. No need to add yourself!
|
|
|
|
Now we can train the LightningModule on a TPU without doing anything else!
|
|
|
|
.. code-block:: python

    dm = MNISTDataModule()
    model = LitMNIST()
    trainer = Trainer(tpu_cores=8)
    trainer.fit(model, dm)
|
|
|
|
You'll now see the TPU cores booting up.
|
|
|
|
.. figure:: /_images/mnist_imgs/tpu_start.png
|
|
:alt: TPU start
|
|
:width: 400
|
|
|
|
Notice the epoch is MUCH faster!
|
|
|
|
.. figure:: /_images/mnist_imgs/tpu_fast.png
|
|
:alt: TPU speed
|
|
:width: 600
|
|
|
|
----------------
|
|
|
|
.. include:: hyperparameters.rst
|
|
|
|
----------------
|
|
|
|
Validating
|
|
----------
|
|
|
|
In most cases, we stop training the model when the loss on a validation
split of the data reaches a minimum.
|
|
|
|
Just like the `training_step`, we can define a `validation_step` to check whatever
|
|
metrics we care about, generate samples or add more to our logs.
|
|
|
|
Since the `validation_step` processes a single batch, use the `EvalResult` to log metrics for the full epoch.
|
|
|
|
.. code-block:: python

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss)

        # equivalent
        result.log('val_loss', loss, prog_bar=False, logger=True, on_step=False, on_epoch=True, reduce_fx=torch.mean)
        return result
|
|
|
|
Now we can train with a validation loop as well.
|
|
|
|
.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitMNIST()
    trainer = Trainer(tpu_cores=8)
    trainer.fit(model, train_loader, val_loader)
|
|
|
|
You may have noticed the words `Validation sanity check` logged. This is because Lightning runs 2 batches
|
|
of validation before starting to train. This is a kind of unit test to make sure that if you have a bug
|
|
in the validation loop, you won't need to potentially wait a full epoch to find out.
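The idea behind the sanity check can be shown with a toy loop. The `fit` function below is a hypothetical stand-in, not Lightning's Trainer: it runs a couple of validation batches before any training, so a broken `validation_step` fails immediately. The parameter name `num_sanity_val_steps` mirrors the Trainer flag of the same name, and its default of 2 follows the text above.

```python
from itertools import islice


def fit(train_batches, val_batches, validation_step, training_step, num_sanity_val_steps=2):
    """Toy loop: validate a few batches first, then train."""
    log = []
    # sanity check: a bug in validation_step surfaces here, before any training
    for batch in islice(val_batches, num_sanity_val_steps):
        validation_step(batch)
        log.append("sanity")
    for batch in train_batches:
        training_step(batch)
        log.append("train")
    return log


def broken_validation_step(batch):
    raise ValueError("bug in validation loop")


trained = []
try:
    fit(range(1000), range(10), broken_validation_step, trained.append)
    failed_early = False
except ValueError:
    failed_early = True

trained_before_fix = list(trained)  # stays empty: the bug was caught first
ok_log = fit(range(3), range(10), lambda batch: None, trained.append)
```

With the broken step, none of the 1000 training batches is processed before the error appears.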
|
|
|
|
.. note:: Lightning disables gradients, puts model in eval mode and does everything needed for validation.
|
|
|
|
Val loop under the hood
|
|
^^^^^^^^^^^^^^^^^^^^^^^
|
|
Under the hood, Lightning does the following:
|
|
|
|
.. code-block:: python

    model = Model()
    model.train()
    torch.set_grad_enabled(True)

    for epoch in epochs:
        for batch in data:
            # ...
            # train

        # validate
        model.eval()
        torch.set_grad_enabled(False)

        outputs = []
        for batch in val_data:
            x, y = batch                        # validation_step
            y_hat = model(x)                    # validation_step
            loss = loss_fn(y_hat, y)            # validation_step
            outputs.append({'val_loss': loss})  # validation_step

        # validation_epoch_end
        full_loss = torch.stack([o['val_loss'] for o in outputs]).mean()

        # back to training mode for the next epoch
        model.train()
        torch.set_grad_enabled(True)
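The validation half of that loop can be run on its own in plain PyTorch. This is a minimal sketch with a made-up model and synthetic data, not Lightning's internal code. It shows the switch to eval mode, the gradient-free pass over the validation batches, and the aggregation of the per-batch losses into one epoch-level value.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
# hypothetical model and data: 4 features, 3 classes, 5 validation batches
model = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
val_data = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(5)]

model.eval()                      # eval mode for layers such as dropout
outputs = []
with torch.no_grad():             # no gradients are tracked while validating
    for x, y in val_data:
        loss = F.nll_loss(model(x), y)       # per-batch work: validation_step
        outputs.append({"val_loss": loss})

# epoch-level aggregation: validation_epoch_end
full_loss = torch.stack([o["val_loss"] for o in outputs]).mean()
model.train()                     # back to training mode for the next epoch
```

`torch.stack` is needed because `outputs` is a Python list of dictionaries, which has no `mean` method of its own.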
|
|
|
|
Optional methods
|
|
^^^^^^^^^^^^^^^^
|
|
If you still need even finer-grained control, define the other optional methods for the loop.

.. code-block:: python

    def validation_step(self, batch, batch_idx):
        result = pl.EvalResult()
        result.prediction = some_prediction
        return result

    def validation_epoch_end(self, val_step_outputs):
        # do something with all the predictions from each validation_step
        all_predictions = val_step_outputs.prediction
|
|
|
|
----------------
|
|
|
|
Testing
|
|
-------
|
|
Once our research is done and we're about to publish or deploy a model, we normally want to figure out
|
|
how it will generalize in the "real world." For this, we use a held-out split of the data for testing.
|
|
|
|
Just like the validation loop, we define a test loop:

.. code-block:: python

    class LitMNIST(LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            result = pl.EvalResult()
            result.log('test_loss', loss)
            return result
|
|
|
|
|
|
However, to make sure the test set isn't used inadvertently, Lightning has a separate API to run tests.
|
|
Once you train your model simply call `.test()`.
|
|
|
|
.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitMNIST()
    trainer = Trainer(tpu_cores=8)
    trainer.fit(model)

    # run test set
    result = trainer.test()
    print(result)

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    --------------------------------------------------------------
    TEST RESULTS
    {'test_loss': tensor(1.1703, device='cuda:0')}
    --------------------------------------------------------------
|
|
|
|
You can also run the test from a saved Lightning model:

.. code-block:: python

    model = LitMNIST.load_from_checkpoint(PATH)
    trainer = Trainer(tpu_cores=8)
    trainer.test(model)
|
|
|
|
.. note:: Lightning disables gradients, puts model in eval mode and does everything needed for testing.
|
|
|
|
.. warning:: .test() is not stable yet on TPUs. We're working on getting around the multiprocessing challenges.
|
|
|
|
----------------
|
|
|
|
Predicting
|
|
----------
|
|
Again, a LightningModule is exactly the same as a PyTorch module. This means you can load it
|
|
and use it for prediction.
|
|
|
|
.. code-block:: python

    model = LitMNIST.load_from_checkpoint(PATH)
    x = torch.Tensor(1, 1, 28, 28)
    out = model(x)
|
|
|
|
On the surface, `forward` and `training_step` look similar. Generally, `forward` should contain
what you want the model to do at prediction time, whereas `training_step` likely calls `forward` from
within it.
|
|
|
|
.. testcode::

    class MNISTClassifier(LightningModule):

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x = self.layer_1(x)
            x = torch.relu(x)
            x = self.layer_2(x)
            x = torch.relu(x)
            x = self.layer_3(x)
            x = torch.log_softmax(x, dim=1)
            return x

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss
|
|
|
|
.. code-block:: python

    model = MNISTClassifier()
    x = mnist_image()
    logits = model(x)

In this case, we've set this LightningModule to predict logits. But we could also have it predict feature maps:
|
|
|
|
.. testcode::

    class MNISTRepresentator(LightningModule):

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x = self.layer_1(x)
            x1 = torch.relu(x)
            x = self.layer_2(x1)
            x2 = torch.relu(x)
            x3 = self.layer_3(x2)
            return [x, x1, x2, x3]

        def training_step(self, batch, batch_idx):
            x, y = batch
            _, l1_feats, l2_feats, l3_feats = self(x)
            # the class scores come from the last layer's output
            logits = torch.log_softmax(l3_feats, dim=1)
            ce_loss = F.nll_loss(logits, y)
            loss = perceptual_loss(l1_feats, l2_feats, l3_feats) + ce_loss
            return loss
|
|
|
|
.. code-block:: python

    model = MNISTRepresentator.load_from_checkpoint(PATH)
    x = mnist_image()
    feature_maps = model(x)

Or maybe we have a model that we use to do generation:
|
|
|
|
.. testcode::

    class LitMNISTDreamer(LightningModule):

        def forward(self, z):
            imgs = self.decoder(z)
            return imgs

        def training_step(self, batch, batch_idx):
            x, y = batch
            representation = self.encoder(x)
            imgs = self(representation)

            loss = perceptual_loss(imgs, x)
            return loss

.. code-block:: python

    model = LitMNISTDreamer.load_from_checkpoint(PATH)
    z = sample_noise()
    generated_imgs = model(z)
|
|
|
|
How you split up what goes in `forward` vs `training_step` depends on how you want to use this model for
|
|
prediction.
|
|
|
|
----------------
|
|
|
|
Extensibility
|
|
-------------
|
|
Although Lightning makes everything simple, it doesn't sacrifice any flexibility or control.
|
|
Lightning offers multiple ways of managing the training state.
|
|
|
|
Training overrides
|
|
^^^^^^^^^^^^^^^^^^
|
|
|
|
Any part of the training, validation and testing loop can be modified.
|
|
For instance, if you wanted to do your own backward pass, you would override the
default implementation

.. testcode::

    def backward(self, use_amp, loss, optimizer, optimizer_idx):
        loss.backward()

with your own:

.. testcode::

    class LitMNIST(LightningModule):

        def backward(self, use_amp, loss, optimizer, optimizer_idx):
            # do a custom way of backward
            loss.backward(retain_graph=True)
|
|
|
|
Or if you wanted to initialize ddp in a different way than the default one

.. testcode::

    def configure_ddp(self, model, device_ids):
        # Lightning DDP simply routes to test_step, val_step, etc...
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        return model

you could do your own:

.. testcode::

    class LitMNIST(LightningModule):

        def configure_ddp(self, model, device_ids):

            model = Horovod(model)
            # model = Ray(model)
            return model
|
|
|
|
Every single part of training is configurable this way.
|
|
For a full list look at `LightningModule <lightning-module.rst>`_.
|
|
|
|
----------------
|
|
|
|
Callbacks
|
|
---------
|
|
Another way to add arbitrary functionality is to add a custom callback
for hooks that you might care about:

.. testcode::

    from pytorch_lightning.callbacks import Callback

    class MyPrintingCallback(Callback):

        def on_init_start(self, trainer):
            print('Starting to init trainer!')

        def on_init_end(self, trainer):
            print('Trainer is init now')

        def on_train_end(self, trainer, pl_module):
            print('do something when training ends')
|
|
|
|
And pass the callbacks into the trainer:

.. testcode::

    trainer = Trainer(callbacks=[MyPrintingCallback()])

.. testoutput::
    :hide:

    Starting to init trainer!
    Trainer is init now
|
|
|
|
.. note::
    See full list of 12+ hooks in the :ref:`callbacks`.
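The callback mechanism itself is simple to picture. The plain-Python sketch below uses invented classes (`ToyTrainer`, `RecordingCallback`) to show the pattern: a base class whose hooks do nothing, a subclass that overrides only the hook it cares about, and a trainer that calls every registered callback at the matching point. It is an illustration of the pattern, not Lightning's implementation.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""

    def on_train_start(self, trainer):
        pass

    def on_train_end(self, trainer):
        pass


class ToyTrainer:
    """Hypothetical trainer that only knows how to fire hooks."""

    def __init__(self, callbacks=None):
        self.callbacks = callbacks or []

    def _fire(self, hook_name):
        for callback in self.callbacks:
            getattr(callback, hook_name)(self)

    def fit(self):
        self._fire("on_train_start")
        # ... the training loop would run here ...
        self._fire("on_train_end")


class RecordingCallback(Callback):
    def __init__(self):
        self.events = []

    def on_train_end(self, trainer):
        self.events.append("train ended")


recorder = RecordingCallback()
ToyTrainer(callbacks=[recorder]).fit()
```

Because unused hooks fall back to the no-op base methods, a callback only needs to define the events it reacts to.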
|
|
|
|
----------------
|
|
|
|
.. include:: child_modules.rst
|
|
|
|
----------------
|
|
|
|
.. include:: transfer_learning.rst
|