2020-03-02 03:15:55 +00:00

Introduction Guide
==================

PyTorch Lightning provides a very simple template for organizing your PyTorch code. Once
you've organized it into a LightningModule, it automates most of the training for you.

To illustrate, here's the typical PyTorch project structure organized in a LightningModule.

.. figure:: /_images/mnist_imgs/pt_to_pl.jpg
    :alt: Convert from PyTorch to Lightning


As your project grows in complexity with things like 16-bit precision, distributed training, etc., the part in blue
quickly becomes onerous and starts distracting from the core research code.

---------

Goal of this guide
------------------
This guide walks through the major parts of the library to help you understand
what each part does. But at the end of the day, you write the same PyTorch code... just organize it
into the LightningModule template, which means you keep ALL the flexibility without having to deal with
any of the boilerplate code.

To show how Lightning works, we'll start with an MNIST classifier. We'll end by showing how
to use inheritance to very quickly create an AutoEncoder.

.. note:: Any DL/ML PyTorch project fits into the Lightning structure. Here we just focus on 3 types
    of research to illustrate.

---------

Lightning Philosophy
--------------------
Lightning factors DL/ML code into three types:

- Research code
- Engineering code
- Non-essential code


Research code
^^^^^^^^^^^^^
In the MNIST generation example, the research code would be the particular system and how it's trained (i.e., a GAN or VAE).
In Lightning, this code is abstracted out by the `LightningModule`.

.. code-block:: python

    # schematic: Decoder, features, perceptual_loss and CE stand in for your own code
    l1 = nn.Linear(...)
    l2 = nn.Linear(...)
    decoder = Decoder()

    x1 = l1(x)
    x2 = l2(x)
    out = decoder(features, x)

    loss = perceptual_loss(x1, x2, x) + CE(out, x)


Engineering code
^^^^^^^^^^^^^^^^
The Engineering code is all the code related to training this system: things such as early stopping, distribution
over GPUs, 16-bit precision, etc. This is normally code that is THE SAME across most projects.

In Lightning, this code is abstracted out by the `Trainer`.

.. code-block:: python

    # schematic: gpu_zero and download_data stand in for "run only on GPU 0" logic
    model.cuda(0)
    x = x.cuda(0)

    distributed = DistributedDataParallel(model)

    with gpu_zero:
        download_data()

    dist.barrier()


Non-essential code
^^^^^^^^^^^^^^^^^^
This is code that helps the research but isn't part of the research itself. Some examples might be:

1. Inspecting gradients.
2. Logging to TensorBoard.

In Lightning, this code is abstracted out by `Callbacks`.

.. code-block:: python

    # log samples
    z = Q.rsample()
    generated = decoder(z)
    self.experiment.log('images', generated)
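
The callback idea can be sketched without Lightning at all. In the toy loop below, `Callback`, `LossHistory`, `ToyTrainer` and the `on_epoch_end(epoch, metrics)` hook signature are hypothetical names chosen for illustration, not Lightning's API. The point is that non-essential code lives in a separate object that the loop calls at fixed points, so the research code never mentions it.

.. code-block:: python

    class Callback:
        """Base class: every hook is a no-op unless a subclass overrides it."""

        def on_epoch_end(self, epoch, metrics):
            pass


    class LossHistory(Callback):
        """Non-essential code: records metrics without touching the training logic."""

        def __init__(self):
            self.history = []

        def on_epoch_end(self, epoch, metrics):
            self.history.append((epoch, metrics["loss"]))


    class ToyTrainer:
        """Hypothetical loop that owns the engineering code and calls the callbacks."""

        def __init__(self, callbacks=()):
            self.callbacks = list(callbacks)

        def fit(self, train_one_epoch, max_epochs):
            for epoch in range(max_epochs):
                metrics = train_one_epoch(epoch)  # research code
                for callback in self.callbacks:   # non-essential code
                    callback.on_epoch_end(epoch, metrics)


    history = LossHistory()
    ToyTrainer(callbacks=[history]).fit(lambda epoch: {"loss": 1.0 / (epoch + 1)}, max_epochs=3)
    print(history.history[:2])  # [(0, 1.0), (1, 0.5)]

Swapping `LossHistory` for a different callback changes what gets recorded without editing either the loop or the training function.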

---------

Elements of a research project
------------------------------
Every research project requires the same core ingredients:

1. A model
2. Train/val/test data
3. Optimizer(s)
4. Training step computations
5. Validation step computations
6. Test step computations

The Model
^^^^^^^^^
The LightningModule provides the structure on how to organize these 6 ingredients.

Let's start with the model. In this case we'll design
a 3-layer neural network.

.. code-block:: python

    import torch
    from torch.nn import functional as F
    from torch import nn
    import pytorch_lightning as pl


    class LitMNIST(pl.LightningModule):

        def __init__(self):
            super(LitMNIST, self).__init__()

            # mnist images are (1, 28, 28) (channels, width, height)
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 256)
            self.layer_3 = torch.nn.Linear(256, 10)

        def forward(self, x):
            batch_size, channels, width, height = x.size()

            # (b, 1, 28, 28) -> (b, 1*28*28)
            x = x.view(batch_size, -1)

            # layer 1
            x = self.layer_1(x)
            x = torch.relu(x)

            # layer 2
            x = self.layer_2(x)
            x = torch.relu(x)

            # layer 3
            x = self.layer_3(x)

            # log-probability distribution over the 10 labels
            x = torch.log_softmax(x, dim=1)

            return x

Notice this is a `LightningModule` instead of a `torch.nn.Module`. A LightningModule is
equivalent to a PyTorch Module except it has added functionality. However, you can use it
EXACTLY the same as you would a PyTorch Module.

.. code-block:: python

    net = LitMNIST()
    x = torch.rand(1, 1, 28, 28)  # one random 28x28 grayscale "image"
    out = net(x)
    print(out.size())

.. rst-class:: sphx-glr-script-out

 Out:

 .. code-block:: none

    torch.Size([1, 10])
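
Because a LightningModule is used exactly like an `nn.Module`, the forward pass can be checked in isolation. The sketch below rebuilds the same three layers as a plain `torch.nn.Module` (the name `PlainMNIST` is hypothetical, used only so the example runs without Lightning installed) and confirms that every output row is a log-probability distribution over the 10 classes.

.. code-block:: python

    import torch
    from torch import nn


    class PlainMNIST(nn.Module):
        """Same layers and forward pass as LitMNIST, but a plain nn.Module."""

        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 256)
            self.layer_3 = nn.Linear(256, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)  # (b, 1, 28, 28) -> (b, 784)
            x = torch.relu(self.layer_1(x))
            x = torch.relu(self.layer_2(x))
            return torch.log_softmax(self.layer_3(x), dim=1)


    net = PlainMNIST()
    out = net(torch.rand(4, 1, 28, 28))  # a batch of 4 random "images"
    print(out.size())                    # torch.Size([4, 10])

    # exp(log-probabilities) sums to 1 for every example in the batch
    row_sums = out.exp().sum(dim=1)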

Data
^^^^
The LightningModule organizes your dataloaders and data processing as well.
Here's the PyTorch code for loading MNIST:

.. code-block:: python

    from torch.utils.data import DataLoader, random_split
    from torchvision.datasets import MNIST
    import os
    from torchvision import transforms

    # transforms
    # prepare transforms standard to MNIST
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])

    # data
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    train_loader = DataLoader(mnist_train, batch_size=64)
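
`transforms.Normalize((0.1307,), (0.3081,))` applies `(x - mean) / std` to each channel; 0.1307 and 0.3081 are the commonly used mean and standard deviation of MNIST pixel values after `ToTensor()` has scaled them to [0, 1]. A quick check of what that does to the two extreme pixel values:

.. code-block:: python

    MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


    def normalize(pixel):
        """What Normalize does to one channel value already scaled to [0, 1]."""
        return (pixel - MNIST_MEAN) / MNIST_STD


    print(round(normalize(0.0), 4))  # black pixel -> -0.4242
    print(round(normalize(1.0), 4))  # white pixel -> 2.8215

After normalization the inputs are roughly zero-centred with unit variance, which is the point of using the dataset's own statistics.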

When using PyTorch Lightning, we use the exact same code except we organize it into
the LightningModule:

.. code-block:: python

    from torch.utils.data import DataLoader, random_split
    from torchvision.datasets import MNIST
    import os
    from torchvision import transforms


    class LitMNIST(pl.LightningModule):

        def train_dataloader(self):
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_train = MNIST(os.getcwd(), train=True, download=True,
                                transform=transform)
            return DataLoader(mnist_train, batch_size=64)

Notice the code is exactly the same, except now the training data loading has been organized by the LightningModule
under the `train_dataloader` method. This is great because if you run into a project that uses Lightning and want
to figure out how they prepare their training data, you can just look in the `train_dataloader` method.


Usually, though, we want to separate the data-processing steps that write to disk from
things like transforms, which happen in memory.

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def prepare_data(self):
            # download only
            MNIST(os.getcwd(), train=True, download=True)

        def train_dataloader(self):
            # no download, just transform
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_train = MNIST(os.getcwd(), train=True, download=False,
                                transform=transform)
            return DataLoader(mnist_train, batch_size=64)

Doing the download in the `prepare_data` method ensures that when you have
multiple GPUs, the processes won't overwrite each other's data. This is a contrived example,
but it gets more complicated with things like NLP or ImageNet.

In general, fill these methods with the following:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def prepare_data(self):
            # stuff here is done once at the very beginning of training
            # before any distributed training starts

            # download stuff
            # save to disk
            # etc...
            ...

        def train_dataloader(self):
            # data transforms
            # dataset creation
            # return a DataLoader
            ...
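
The division of labour between the two hooks can be illustrated with a toy driver. This is a hypothetical sketch, not how Lightning schedules hooks internally: `ToyDataModule`, `run` and the simulated "processes" (plain loop iterations inside one Python process) are assumptions. It only shows the contract: disk side effects run once, in-memory setup runs once per process.

.. code-block:: python

    class ToyDataModule:
        def __init__(self):
            self.downloads = 0
            self.loaders_built = 0

        def prepare_data(self):
            # disk side effects (download, save to disk): must happen only once
            self.downloads += 1

        def train_dataloader(self):
            # in-memory work (transforms, dataset objects): safe to repeat per process
            self.loaders_built += 1
            return list(range(8))  # stand-in for a real DataLoader


    def run(module, world_size):
        """Hypothetical driver: prepare once, then build one loader per process."""
        module.prepare_data()
        return [module.train_dataloader() for _rank in range(world_size)]


    m = ToyDataModule()
    loaders = run(m, world_size=4)
    print(m.downloads, m.loaders_built)  # 1 4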

Optimizer
^^^^^^^^^
Next we choose what optimizer to use for training our system.
In PyTorch we do it as follows:

.. code-block:: python

    from torch.optim import Adam
    optimizer = Adam(LitMNIST().parameters(), lr=1e-3)

In Lightning we do the same but organize it under the `configure_optimizers` method.
If you don't define this, Lightning will automatically use `Adam(self.parameters(), lr=1e-3)`.

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def configure_optimizers(self):
            return Adam(self.parameters(), lr=1e-3)

Training step
^^^^^^^^^^^^^
The training step is what happens inside the training loop.

.. code-block:: python

    for epoch in epochs:
        for batch in data:
            # TRAINING STEP
            # ....
            # TRAINING STEP
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

In the case of MNIST, we do the following:

.. code-block:: python

    for epoch in epochs:
        for batch in data:
            # TRAINING STEP START
            x, y = batch
            logits = model(x)
            loss = F.nll_loss(logits, y)
            # TRAINING STEP END

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
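
For a version of this loop that actually runs, the sketch below substitutes random tensors for MNIST (an assumption made so nothing is downloaded) and a small `nn.Sequential` stand-in for the model. Everything between the two `TRAINING STEP` comments is what would move into `training_step`.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)

    # 256 fake "images" in MNIST's (1, 28, 28) shape, with labels 0-9
    data = DataLoader(
        TensorDataset(torch.rand(256, 1, 28, 28), torch.randint(0, 10, (256,))),
        batch_size=64,
    )
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax(dim=1))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    epoch_losses = []
    for epoch in range(2):
        for batch in data:
            # TRAINING STEP START
            x, y = batch
            logits = model(x)
            loss = F.nll_loss(logits, y)
            # TRAINING STEP END

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        epoch_losses.append(loss.item())  # last batch loss of each epoch

    print(len(epoch_losses))  # 2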

In Lightning, everything that is in the training step gets organized under the `training_step` function
in the LightningModule:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return {'loss': loss}
            # return loss (also works)

Again, this is the same PyTorch code except that it has been organized by the LightningModule.
This code is not restricted, which means it can be as complicated as a full seq-2-seq, RL loop, GAN, etc...
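
To see why the loop can stay generic, here is a toy "trainer" that only knows how to call `training_step`. It is a hypothetical sketch, not Lightning's implementation: `ToyModule`, `toy_fit` and the integer stand-ins for tensors are assumptions. It also mirrors the two return styles mentioned above, a dict with a `'loss'` key or the bare loss.

.. code-block:: python

    class ToyModule:
        """Research code only: how to compute a loss for one batch."""

        def training_step(self, batch, batch_idx):
            x, y = batch
            prediction = 2 * x            # stand-in for self(x)
            loss = (prediction - y) ** 2  # stand-in for F.nll_loss
            return {"loss": loss}         # returning the bare loss also works


    def toy_fit(module, data, epochs):
        """Hypothetical engineering code: owns the loops, the module owns the step."""
        losses = []
        for _epoch in range(epochs):
            for batch_idx, batch in enumerate(data):
                out = module.training_step(batch, batch_idx)
                loss = out["loss"] if isinstance(out, dict) else out
                # a real trainer would call loss.backward() and optimizer.step() here
                losses.append(loss)
        return losses


    print(toy_fit(ToyModule(), data=[(1, 2), (2, 5)], epochs=1))  # [0, 1]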

---------

Training
--------
So far we defined 4 key ingredients in pure PyTorch but organized the code inside the LightningModule.

1. Model.
2. Training data.
3. Optimizer.
4. What happens in the training loop.

For clarity, here is the full LightningModule so far.

.. code-block:: python

    import os

    import torch
    from torch.nn import functional as F
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import MNIST
    import pytorch_lightning as pl


    class LitMNIST(pl.LightningModule):

        def __init__(self):
            super(LitMNIST, self).__init__()
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 256)
            self.layer_3 = torch.nn.Linear(256, 10)

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x = self.layer_1(x)
            x = torch.relu(x)
            x = self.layer_2(x)
            x = torch.relu(x)
            x = self.layer_3(x)
            x = torch.log_softmax(x, dim=1)
            return x

        def prepare_data(self):
            # download only
            MNIST(os.getcwd(), train=True, download=True)

        def train_dataloader(self):
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
            return DataLoader(mnist_train, batch_size=64)

        def configure_optimizers(self):
            return Adam(self.parameters(), lr=1e-3)

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)

            # add logging
            logs = {'loss': loss}
            return {'loss': loss, 'log': logs}

Again, this is the same PyTorch code, except that it's organized
by the LightningModule. This organization now lets us train this model:


Train on CPU
^^^^^^^^^^^^

.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitMNIST()
    trainer = Trainer()
    trainer.fit(model)

You should see the following weights summary and progress bar:

.. figure:: /_images/mnist_imgs/mnist_cpu_bar.png
    :alt: mnist CPU bar
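
The parameter counts in that weights summary can be checked by hand: an `nn.Linear(in_features, out_features)` layer holds `in_features * out_features` weights plus `out_features` biases. For the three layers of `LitMNIST`:

.. code-block:: python

    def linear_params(in_features, out_features):
        """Weights plus biases of one fully connected layer."""
        return in_features * out_features + out_features


    layers = [(28 * 28, 128), (128, 256), (256, 10)]
    counts = [linear_params(i, o) for i, o in layers]
    print(counts, sum(counts))  # [100480, 33024, 2570] 136074

Most of the parameters sit in `layer_1`, because it sees all 784 input pixels.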

Logging
^^^^^^^

When we added the `log` key in the return dictionary, it went into the built-in TensorBoard logger.
But you could have also logged by calling:

.. code-block:: python

    def training_step(self, batch, batch_idx):
        # ...
        loss = ...
        # with the default TensorBoard logger, `experiment` is a SummaryWriter
        self.logger.experiment.add_scalar('loss', loss)

This will generate TensorBoard logs automatically.

.. figure:: /_images/mnist_imgs/mnist_tb.png
    :alt: mnist TensorBoard logs

But you can also use any of the `other loggers <loggers.rst>`_ we support.

GPU training
^^^^^^^^^^^^

The real power, though, is in the Trainer flags. For instance, to run this model on a GPU:

.. code-block:: python

    model = LitMNIST()
    trainer = Trainer(gpus=1)
    trainer.fit(model)

.. figure:: /_images/mnist_imgs/mnist_gpu.png
    :alt: mnist GPU bar


Multi-GPU training
^^^^^^^^^^^^^^^^^^

You can also train on multiple GPUs.

.. code-block:: python

    model = LitMNIST()
    trainer = Trainer(gpus=8)
    trainer.fit(model)

Or on multiple nodes:

.. code-block:: python

    # 4 nodes x 8 GPUs = 32 GPUs
    model = LitMNIST()
    trainer = Trainer(gpus=8, num_nodes=4, distributed_backend='ddp')
    trainer.fit(model)

Refer to the `distributed computing guide for more details <multi_gpu.rst>`_.
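
A quick way to reason about the multi-node example is to count processes. Assuming one process per GPU (as with `ddp`), the `batch_size=64` used in `train_dataloader`, and the 60,000 MNIST training images, the arithmetic for `gpus=8, num_nodes=4` works out as below. This is a back-of-the-envelope sketch: it ignores any padding the distributed sampler may add when the dataset does not divide evenly.

.. code-block:: python

    import math

    gpus_per_node, num_nodes, batch_size, train_size = 8, 4, 64, 60_000

    world_size = gpus_per_node * num_nodes                   # one process per GPU
    samples_per_process = math.ceil(train_size / world_size)  # each process sees a shard
    effective_batch = batch_size * world_size                 # examples per optimizer step
    steps_per_epoch = math.ceil(samples_per_process / batch_size)

    print(world_size, samples_per_process, effective_batch, steps_per_epoch)
    # 32 1875 2048 30

The effective batch size grows with the number of GPUs, which is one reason learning rates are often re-tuned when scaling out.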

TPUs
^^^^
Did you know you can use PyTorch on TPUs? It's very hard to do, but we've
worked with the XLA team to use their awesome library to get this to work
out of the box!

Let's train on Colab (`full demo available here <https://colab.research.google.com/drive/1-_LKx4HwAxl5M6xPJmqAAu444LTDQoa3>`_).

First, change the runtime to TPU (and reinstall Lightning).

.. figure:: /_images/mnist_imgs/runtime_tpu.png
    :alt: Change the Colab runtime to TPU

.. figure:: /_images/mnist_imgs/restart_runtime.png
    :alt: Restart the Colab runtime


Next, install the required XLA library (it adds support for PyTorch on TPUs):

.. code-block:: python

    import collections
    from datetime import datetime, timedelta
    import os
    import requests
    import threading

    _VersionConfig = collections.namedtuple('_VersionConfig', 'wheels,server')
    VERSION = "torch_xla==nightly"  #@param ["xrt==1.15.0", "torch_xla==nightly"]
    CONFIG = {
        'xrt==1.15.0': _VersionConfig('1.15', '1.15.0'),
        'torch_xla==nightly': _VersionConfig('nightly', 'XRT-dev{}'.format(
            (datetime.today() - timedelta(1)).strftime('%Y%m%d'))),
    }[VERSION]
    DIST_BUCKET = 'gs://tpu-pytorch/wheels'
    TORCH_WHEEL = 'torch-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCH_XLA_WHEEL = 'torch_xla-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)
    TORCHVISION_WHEEL = 'torchvision-{}-cp36-cp36m-linux_x86_64.whl'.format(CONFIG.wheels)

    # Update TPU XRT version
    def update_server_xrt():
        print('Updating server-side XRT to {} ...'.format(CONFIG.server))
        url = 'http://{TPU_ADDRESS}:8475/requestversion/{XRT_VERSION}'.format(
            TPU_ADDRESS=os.environ['COLAB_TPU_ADDR'].split(':')[0],
            XRT_VERSION=CONFIG.server,
        )
        print('Done updating server-side XRT: {}'.format(requests.post(url)))

    update = threading.Thread(target=update_server_xrt)
    update.start()

.. code-block::

    # Colab cell: lines starting with ! are run as shell commands
    # Install Colab TPU compat PyTorch/TPU wheels and dependencies
    !pip uninstall -y torch torchvision
    !gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
    !gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
    !pip install "$TORCH_WHEEL"
    !pip install "$TORCH_XLA_WHEEL"
    !pip install "$TORCHVISION_WHEEL"
    !sudo apt-get install libomp5
    update.join()

In distributed training (multiple GPUs and multiple TPU cores), each GPU or TPU core will run a copy
of this program. This means that, without taking any care, you will download the dataset N times, which
will cause all sorts of issues.

To solve this problem, move the download code to the `prepare_data` method in the LightningModule.
In this method we do all the preparation we need to do once (instead of on every GPU).

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def prepare_data(self):
            # transform
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

            # download
            mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
            mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

            # train/val split
            mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

            # assign to use in dataloaders
            self.train_dataset = mnist_train
            self.val_dataset = mnist_val
            self.test_dataset = mnist_test

        def train_dataloader(self):
            return DataLoader(self.train_dataset, batch_size=64)

        def val_dataloader(self):
            return DataLoader(self.val_dataset, batch_size=64)

        def test_dataloader(self):
            return DataLoader(self.test_dataset, batch_size=64)

The `prepare_data` method is also a good place to do any data processing that needs to be done only
once (e.g., downloading or tokenizing).

.. note:: Lightning inserts the correct DistributedSampler for distributed training. No need to add it yourself!

Now we can train the LightningModule on a TPU without doing anything else!

.. code-block:: python

    model = LitMNIST()
    trainer = Trainer(num_tpu_cores=8)
    trainer.fit(model)

You'll now see the TPU cores booting up.

.. figure:: /_images/mnist_imgs/tpu_start.png
    :alt: TPU start

Notice the epoch is MUCH faster!

.. figure:: /_images/mnist_imgs/tpu_fast.png
    :alt: TPU speed

---------

Hyperparameters
---------------
Normally, we don't hard-code these values into a model. Instead, we use the command line to
modify the network.

.. code-block:: python

    from argparse import ArgumentParser

    parser = ArgumentParser()

    # parametrize the network
    parser.add_argument('--layer_1_dim', type=int, default=128)
    parser.add_argument('--layer_2_dim', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    args = parser.parse_args()
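
`parse_args` also accepts an explicit list of strings, which is handy for checking defaults in a notebook or a test without a real command line. The `--learning_rate` flag is included here because the `configure_optimizers` method in the next block reads `hparams.learning_rate`; the values passed below are arbitrary.

.. code-block:: python

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--layer_1_dim', type=int, default=128)
    parser.add_argument('--layer_2_dim', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=1e-3)

    # pass an explicit list instead of reading sys.argv
    args = parser.parse_args(['--layer_1_dim', '64', '--learning_rate', '0.01'])
    print(args.layer_1_dim, args.layer_2_dim, args.learning_rate)  # 64 256 0.01

Flags that are not passed keep their defaults, and `type=` converts the strings to `int` or `float`.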

Now we can parametrize the LightningModule.

.. code-block:: python
    :emphasize-lines: 4,6,7,8,15,18

    class LitMNIST(pl.LightningModule):
        def __init__(self, hparams):
            super(LitMNIST, self).__init__()
            self.hparams = hparams

            self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
            self.layer_2 = torch.nn.Linear(hparams.layer_1_dim, hparams.layer_2_dim)
            self.layer_3 = torch.nn.Linear(hparams.layer_2_dim, 10)

        def forward(self, x):
            ...

        def train_dataloader(self):
            ...
            return DataLoader(mnist_train, batch_size=self.hparams.batch_size)

        def configure_optimizers(self):
            return Adam(self.parameters(), lr=self.hparams.learning_rate)


    hparams = parser.parse_args()
    model = LitMNIST(hparams)

.. note:: Bonus! If you assign `hparams` in your module, Lightning will save them into the checkpoint and restore your
    model using exactly those hparams.

We can also add all the flags available in the Trainer to the ArgumentParser.

.. code-block:: python

    # add all the available Trainer options to the ArgumentParser
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

And now you can start your program with:

.. code-block:: bash

    # now you can use any trainer flag
    $ python main.py --num_nodes 2 --gpus 8

For a full guide on using hyperparameters, `check out the hyperparameters docs <hyperparameters.rst>`_.

---------

Validating
----------

In most cases, we stop training the model when the loss on a validation
split of the data reaches a minimum.

Just like the `training_step`, we can define a `validation_step` to check whatever
metrics we care about, generate samples, or add more to our logs.

.. code-block:: python

    for epoch in epochs:
        for batch in data:
            # train
            ...

        # validate
        outputs = []
        for batch in val_data:
            x, y = batch                        # validation_step
            y_hat = model(x)                    # validation_step
            loss = F.nll_loss(y_hat, y)         # validation_step
            outputs.append({'val_loss': loss})  # validation_step

        full_loss = torch.stack([o['val_loss'] for o in outputs]).mean()  # validation_epoch_end

Since the `validation_step` processes a single batch,
in Lightning we also have a `validation_epoch_end` method, which allows you to compute
statistics on the full dataset after an epoch of validation data, not just on one batch.

In addition, we define a `val_dataloader` method, which tells the trainer what data to use for validation.
Notice we split the train split of MNIST into train and validation. We also have to make sure to do the
same split in the `train_dataloader` method, so that the validation images are never used for training.

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return {'val_loss': loss}

        def validation_epoch_end(self, outputs):
            # outputs is the list of dicts returned by validation_step
            avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
            tensorboard_logs = {'val_loss': avg_loss}
            return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

        def val_dataloader(self):
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_train = MNIST(os.getcwd(), train=True, download=False,
                                transform=transform)
            _, mnist_val = random_split(mnist_train, [55000, 5000])
            mnist_val = DataLoader(mnist_val, batch_size=64)
            return mnist_val
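
`random_split` draws a fresh random permutation on every call, so `train_dataloader` and `val_dataloader` only carve out the same 5,000 validation images if the random state is identical when each one runs. A minimal sketch, assuming that seeding the global generator with `torch.manual_seed` right before each split is acceptable in your setup (the seed 42 is arbitrary, and a `range` stands in for the 60,000 MNIST images):

.. code-block:: python

    import torch
    from torch.utils.data import random_split

    dataset = range(60_000)  # stand-in for the MNIST training set


    def split(seed=42):
        torch.manual_seed(seed)  # same seed -> same permutation -> same split
        return random_split(dataset, [55_000, 5_000])


    _, val_a = split()
    _, val_b = split()
    print(len(val_a), val_a.indices == val_b.indices)  # 5000 True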

Again, we've just organized the regular PyTorch code into two steps: the `validation_step` method, which
operates on a single batch, and the `validation_epoch_end` method, which computes statistics on all batches.
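
One detail about averaging in `validation_epoch_end`: the plain mean of per-batch losses weights every batch equally, so a small last batch counts as much as a full one. With 5,000 validation images and `batch_size=64` there are 78 full batches plus one batch of 8. The sketch below uses made-up loss values to compare the plain mean with a mean weighted by batch size, which equals the per-example average.

.. code-block:: python

    batch_sizes = [64] * 78 + [8]        # 78 * 64 + 8 = 5000 validation images
    batch_losses = [0.30] * 78 + [1.30]  # made-up per-batch mean losses

    plain_mean = sum(batch_losses) / len(batch_losses)
    weighted_mean = sum(l * n for l, n in zip(batch_losses, batch_sizes)) / sum(batch_sizes)

    print(round(plain_mean, 4), round(weighted_mean, 4))  # 0.3127 0.3016

To use the weighted version, `validation_step` would also need to return the batch size (for example `len(y)`) alongside the loss.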

If you have these methods defined, Lightning will call them automatically. Now we can train
while checking the validation set.

.. code-block:: python

    from pytorch_lightning import Trainer

    model = LitMNIST()
    trainer = Trainer(num_tpu_cores=8)
    trainer.fit(model)

You may have noticed the words `Validation sanity check` logged. This is because Lightning runs 5 batches
of validation before starting to train. This is a kind of unit test to make sure that if you have a bug
in the validation loop, you won't need to potentially wait a full epoch to find out.

.. note:: Lightning disables gradients, puts the model in eval mode, and does everything needed for validation.
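
In plain PyTorch, the two switches that note refers to are `model.eval()`, which changes layers such as dropout and batch norm to inference behaviour, and `torch.no_grad()`, which stops autograd from recording a graph. A minimal sketch with a small dropout model, assumed purely for illustration:

.. code-block:: python

    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))
    x = torch.rand(3, 4)

    model.eval()           # dropout becomes the identity
    with torch.no_grad():  # no graph is built, so no gradients can flow
        out_1 = model(x)
        out_2 = model(x)

    print(model.training, out_1.requires_grad)  # False False
    print(torch.equal(out_1, out_2))            # True: eval mode is deterministic here

    model.train()          # switch back before training resumes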

---------

Testing
-------
Once our research is done and we're about to publish or deploy a model, we normally want to figure out
how it will generalize in the "real world." For this, we use a held-out split of the data for testing.

Just like the validation loop, we define exactly the same steps for testing:

- test_step
- test_epoch_end
- test_dataloader

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def test_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return {'test_loss': loss}

        def test_epoch_end(self, outputs):
            avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
            tensorboard_logs = {'test_loss': avg_loss}
            return {'test_loss': avg_loss, 'log': tensorboard_logs}

        def test_dataloader(self):
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            # the MNIST test split (10,000 images) is used as-is, so no random_split is needed
            mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)
            return DataLoader(mnist_test, batch_size=64)
|
|
|
|
|
|
|
|
However, to make sure the test set isn't used inadvertently, Lightning has a separate API to run tests.
|
|
|
|
Once you train your model simply call `.test()`.
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
|
2020-03-06 11:25:24 +00:00
|
|
|
model = LitMNIST()
|
2020-03-02 03:15:55 +00:00
|
|
|
trainer = Trainer(num_tpu_cores=8)
|
|
|
|
trainer.fit(model)
|
|
|
|
|
|
|
|
# run test set
|
|
|
|
trainer.test()
|
|
|
|
|
2020-03-06 17:12:39 +00:00
|
|
|
.. rst-class:: sphx-glr-script-out
|
|
|
|
|
|
|
|
Out:
|
|
|
|
|
|
|
|
.. code-block:: none
|
|
|
|
|
|
|
|
--------------------------------------------------------------
|
|
|
|
TEST RESULTS
|
|
|
|
{'test_loss': tensor(1.1703, device='cuda:0')}
|
|
|
|
--------------------------------------------------------------
|
|
|
|
|
2020-03-02 03:15:55 +00:00
|
|
|
You can also run the test from a saved lightning model
|
|
|
|
|
|
|
|
.. code-block:: python

    # PATH is the path to a checkpoint file saved during training
    model = LitMNIST.load_from_checkpoint(PATH)
    trainer = Trainer(num_tpu_cores=8)
    trainer.test(model)

.. note:: Lightning disables gradients, puts the model in eval mode, and does everything else needed for testing.

.. warning:: ``.test()`` is not stable yet on TPUs. We're working on getting around the multiprocessing challenges.

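For intuition about what the note above describes, here is a rough plain-PyTorch sketch of a manual test loop: it switches the model to eval mode, disables gradient tracking with ``torch.no_grad()``, and averages the per-batch losses the way ``test_epoch_end`` does. This is an illustration under our own assumptions, not Lightning's internal implementation; ``run_test_loop``, the toy model, and the random data are hypothetical stand-ins.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset


    def run_test_loop(model, dataloader):
        """Hypothetical helper: average the NLL loss over a test set."""
        model.eval()  # switch dropout/batchnorm layers to inference behaviour
        losses = []
        with torch.no_grad():  # no autograd graph is built while testing
            for x, y in dataloader:
                log_probs = model(x)
                losses.append(F.nll_loss(log_probs, y))
        model.train()  # restore training mode for any later fitting
        return torch.stack(losses).mean()


    # toy model and random data used as placeholders for MNIST
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(28 * 28, 10),
        torch.nn.LogSoftmax(dim=1),
    )
    data = TensorDataset(torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,)))
    test_loss = run_test_loop(model, DataLoader(data, batch_size=8))
    print(test_loss)

Averaging per-batch means equals the true dataset mean only when every batch has the same size, as it does here (32 samples, batch size 8).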
---------

Predicting
----------

Again, a LightningModule is exactly the same as a PyTorch module (it subclasses ``torch.nn.Module``).
This means you can load it and use it for prediction.

.. code-block:: python

    model = LitMNIST.load_from_checkpoint(PATH)
    model.eval()  # use inference behaviour for dropout/batchnorm layers

    # a dummy batch containing one 1x28x28 image
    x = torch.randn(1, 1, 28, 28)
    with torch.no_grad():
        out = model(x)

On the surface, ``forward`` and ``training_step`` look similar. Generally, ``forward`` should hold what you want
the model to do when you use it (for example, at prediction time), whereas ``training_step`` typically calls
``forward`` from within it.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import pytorch_lightning as pl


    class MNISTClassifier(pl.LightningModule):

        def __init__(self):
            super().__init__()
            # mnist images are (1, 28, 28) (channels, width, height)
            self.layer_1 = torch.nn.Linear(28 * 28, 128)
            self.layer_2 = torch.nn.Linear(128, 256)
            self.layer_3 = torch.nn.Linear(256, 10)

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x = self.layer_1(x)
            x = torch.relu(x)
            x = self.layer_2(x)
            x = torch.relu(x)
            x = self.layer_3(x)
            x = torch.log_softmax(x, dim=1)
            return x

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss

.. code-block:: python

    model = MNISTClassifier()
    x = mnist_image()  # placeholder for any (batch, 1, 28, 28) image tensor
    logits = model(x)

In this case, we've set this LightningModule to predict logits. But we could also have it predict feature maps:

.. code-block:: python

    class MNISTRepresentator(pl.LightningModule):
        # __init__ defines layer_1, layer_2 and layer_3 exactly as in MNISTClassifier

        def forward(self, x):
            batch_size, channels, width, height = x.size()
            x = x.view(batch_size, -1)
            x1 = torch.relu(self.layer_1(x))
            x2 = torch.relu(self.layer_2(x1))
            x3 = self.layer_3(x2)
            # return the feature map produced by each layer
            return [x1, x2, x3]

        def training_step(self, batch, batch_idx):
            x, y = batch
            l1_feats, l2_feats, l3_feats = self(x)
            # the last layer's output doubles as the classification scores
            logits = torch.log_softmax(l3_feats, dim=1)
            ce_loss = F.nll_loss(logits, y)
            # perceptual_loss is a placeholder for whatever feature-level loss you use
            loss = perceptual_loss(l1_feats, l2_feats, l3_feats) + ce_loss
            return loss

.. code-block:: python

    model = MNISTRepresentator.load_from_checkpoint(PATH)
    x = mnist_image()
    feature_maps = model(x)

Or maybe we have a model that we use for generation:

.. code-block:: python

    class LitMNISTDreamer(pl.LightningModule):
        # assumes __init__ defines self.encoder and self.decoder

        def forward(self, z):
            imgs = self.decoder(z)
            return imgs

        def training_step(self, batch, batch_idx):
            x, y = batch
            representation = self.encoder(x)
            imgs = self(representation)

            loss = perceptual_loss(imgs, x)
            return loss

.. code-block:: python

    model = LitMNISTDreamer.load_from_checkpoint(PATH)
    z = sample_noise()
    generated_imgs = model(z)

How you split up what goes in ``forward`` vs ``training_step`` depends on how you want to use this model for
prediction.

---------

Extensibility
-------------

Although Lightning makes everything super simple, it doesn't sacrifice any flexibility or control.
Lightning offers multiple ways of managing the training state.

Training overrides
^^^^^^^^^^^^^^^^^^

Any part of the training, validation, and testing loop can be modified.
For instance, if you wanted to do your own backward pass, you would override the
default implementation:

.. code-block:: python

    # amp refers to NVIDIA Apex here (from apex import amp)
    def backward(self, use_amp, loss, optimizer):
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

With your own:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def backward(self, use_amp, loss, optimizer):
            # do a custom way of backward
            loss.backward(retain_graph=True)

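The override works because the training loop only ever calls the hook, so whatever a subclass defines is what runs. The framework-free sketch below illustrates that pattern; ``Module``, ``Loop`` and ``CustomModule`` are hypothetical names, and this is not how Lightning's trainer is actually written.

.. code-block:: python

    class Module:
        """Hypothetical stand-in for a LightningModule that ships a default hook."""

        def backward(self, loss):
            # default behaviour
            return f"default backward on {loss}"


    class CustomModule(Module):
        def backward(self, loss):
            # the override replaces the default without touching the loop
            return f"custom backward on {loss}"


    class Loop:
        """Hypothetical training loop: it calls the hook, never the logic directly."""

        def run(self, module, losses):
            return [module.backward(loss) for loss in losses]


    loop = Loop()
    default_calls = loop.run(Module(), [0.5])
    custom_calls = loop.run(CustomModule(), [0.5])
    print(default_calls, custom_calls)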
Or, if you wanted to initialize DDP differently from the default one:

.. code-block:: python

    def configure_ddp(self, model, device_ids):
        # Lightning DDP simply routes to test_step, val_step, etc...
        model = LightningDistributedDataParallel(
            model,
            device_ids=device_ids,
            find_unused_parameters=True
        )
        return model

You could do your own:

.. code-block:: python

    class LitMNIST(pl.LightningModule):

        def configure_ddp(self, model, device_ids):
            # Horovod and Ray stand in for whatever wrapper you want to use
            model = Horovod(model)
            # model = Ray(model)
            return model

Every single part of training is configurable this way.
For a full list, look at `LightningModule <lightning-module.rst>`_.

---------

Callbacks
---------

Another way to add arbitrary functionality is to add a custom callback
for the hooks you care about:

.. code-block:: python

    import pytorch_lightning as pl


    class MyPrintingCallback(pl.Callback):

        def on_init_start(self, trainer):
            print('Starting to init trainer!')

        def on_init_end(self, trainer):
            print('trainer is init now')

        def on_train_end(self, trainer, pl_module):
            print('do something when training ends')

Then pass the callbacks into the trainer:

.. code-block:: python

    Trainer(callbacks=[MyPrintingCallback()])

.. note::
    See the full list of 12+ hooks in :ref:`callbacks`.

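Conceptually, the trainer calls each hook on every registered callback at the matching point of its lifecycle. A minimal pure-Python sketch of that dispatch is below; ``ToyTrainer`` and ``RecordingCallback`` are hypothetical classes, and Lightning's real ``Trainer`` does far more than this.

.. code-block:: python

    class Callback:
        """Base class: every hook is a no-op unless a subclass overrides it."""

        def on_init_start(self, trainer):
            pass

        def on_init_end(self, trainer):
            pass

        def on_train_end(self, trainer, pl_module):
            pass


    class ToyTrainer:
        """Hypothetical trainer showing how registered callbacks get invoked."""

        def __init__(self, callbacks=None):
            self.callbacks = callbacks or []
            self._call("on_init_start", self)
            # ... real setup would happen here ...
            self._call("on_init_end", self)

        def _call(self, hook_name, *args):
            # invoke the same hook on every callback, in registration order
            for cb in self.callbacks:
                getattr(cb, hook_name)(*args)

        def fit(self, model):
            # ... the training loop would run here ...
            self._call("on_train_end", self, model)


    class RecordingCallback(Callback):
        def __init__(self):
            self.events = []

        def on_init_start(self, trainer):
            self.events.append("init_start")

        def on_init_end(self, trainer):
            self.events.append("init_end")

        def on_train_end(self, trainer, pl_module):
            self.events.append("train_end")


    recorder = RecordingCallback()
    trainer = ToyTrainer(callbacks=[recorder])
    trainer.fit(model=object())
    print(recorder.events)

Because the base class defines every hook as a no-op, a callback only needs to override the hooks it actually cares about.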
---------

.. include:: child_modules.rst

---------

.. include:: transfer_learning.rst
