Docs clean up of results and forward vs training_step confusion (#3584)

William Falcon 2020-09-21 11:17:59 -04:00 committed by GitHub
parent b1347c956a
commit 32303f1022
2 changed files with 200 additions and 117 deletions

README.md

@@ -132,7 +132,7 @@ Get started with our [3 steps guide](https://pytorch-lightning.readthedocs.io/en
## How To Use
#### Setup step: Install
Simple installation from PyPI
```bash
pip install pytorch-lightning
@@ -148,63 +148,57 @@ Install bleeding-edge (no guarantees)
pip install git+https://github.com/PytorchLightning/pytorch-lightning.git@master --upgrade
```
#### Setup step: Add these imports
```python
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
```
#### Step 1: Define a LightningModule
A LightningModule defines a full *system* (i.e., a GAN, an autoencoder, BERT, or a simple image classifier).
```python
# this is just a torch.nn.Module with some structure
class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
```
#### Step 2: Train!
```python
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])
autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
```
#### And without changing a single line of code, you could run on GPUs
```python
# 8 GPUs
trainer = Trainer(max_epochs=1, gpus=8)
# 256 GPUs
trainer = Trainer(max_epochs=1, gpus=8, num_nodes=32)
```

docs/source/new-project.rst

@@ -61,17 +61,15 @@ You could also use conda environments
conda activate my_env
pip install pytorch-lightning
----------
Import the following:

.. code-block:: python
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
@@ -79,34 +77,44 @@ Step 1: Define LightningModule
import pytorch_lightning as pl
from torch.utils.data import random_split
******************************
Step 1: Define LightningModule
******************************

.. code-block::

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
A :class:`~pytorch_lightning.core.LightningModule` defines a *system* such as:
- Autoencoder
- BERT
- DQN
- GAN
- Image classifier
- Seq2seq
- SimCLR
- VAE
It is a :class:`torch.nn.Module` that groups all research code into a single file to make it self-contained (a minimal sketch of these pieces follows the list below):
- The Train loop
- The Validation loop
@@ -114,32 +122,25 @@ The :class:`~pytorch_lightning.core.LightningModule` holds your research code:
- The Model + system architecture
- The Optimizer
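As a rough sketch (not part of the original snippet above), these pieces map onto methods of a single class; ``validation_step`` and ``test_step`` are the standard hooks for the validation and test loops:

.. code-block:: python

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        # the model / system architecture
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):    # the train loop
        ...

    def validation_step(self, batch, batch_idx):  # the validation loop
        ...

    def test_step(self, batch, batch_idx):        # the test loop
        ...

    def configure_optimizers(self):               # the optimizer
        return torch.optim.Adam(self.parameters(), lr=1e-3)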
You can use your :class:`~pytorch_lightning.core.LightningModule` just like a PyTorch model, and
you can customize any part of training (such as the backward pass) by overriding any
of the 20+ hooks found in :ref:`hooks`.

.. code-block:: python

class LitAutoEncoder(pl.LightningModule):

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        loss.backward()
More details in :ref:`lightning_module` docs.
----------
**************************
Step 2: Fit with a Trainer
**************************
First, define the data however you want. Lightning just needs a dataloader for any split you want (train/val/test).
.. code-block:: python
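# a minimal sketch of the data setup (assumed here for illustration, not taken from the original snippet)
from torch.utils.data import DataLoader

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])

train_loader = DataLoader(train)
val_loader = DataLoader(val)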
@@ -149,13 +150,83 @@ First, define the data in whatever way you want. Lightning just needs a dataloader
.. code-block:: python
# init model
model = LitAutoEncoder()
# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
# trainer = pl.Trainer(gpus=8) (if you have GPUs)
trainer = pl.Trainer()
trainer.fit(model, train_loader)
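If you also created a validation dataloader, you can pass it to ``fit`` as well so the validation loop runs during training (a small sketch, assuming the ``val_loader`` from the data snippet above):

.. code-block:: python

# optionally pass a validation dataloader too
trainer.fit(model, train_loader, val_loader)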
-----
*****************
Predict or Deploy
*****************
When you're done training, you have 3 options to use your LightningModule for predictions.
Option 1: Pull out the relevant parts you need for prediction
.. code-block:: python
# ----------------------------------
# to use as embedding extractor
# ----------------------------------
autoencoder = LitAutoEncoder.load_from_checkpoint('path/to/checkpoint_file.ckpt')
model = autoencoder.encoder
model.eval()
# ----------------------------------
# to use as image generator
# ----------------------------------
model = autoencoder.decoder
model.eval()
Option 2: Add a forward method to enable predictions however you want.
.. code-block:: python
# ----------------------------------
# using the AE to extract embeddings
# ----------------------------------
class LitAutoEncoder(pl.LightningModule):

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

autoencoder = LitAutoEncoder()
embedding = autoencoder(torch.rand(1, 28 * 28))

# ----------------------------------
# or using the AE to generate images
# ----------------------------------
class LitAutoEncoder(pl.LightningModule):

    def forward(self):
        # the decoder expects a 3-dimensional latent vector (see __init__ above)
        z = torch.rand(1, 3)
        image = self.decoder(z)
        image = image.view(1, 1, 28, 28)
        return image

autoencoder = LitAutoEncoder()
image_sample = autoencoder()
Option 3: Or export for a production system (TorchScript, ONNX).
.. code-block:: python
# ----------------------------------
# torchscript
# ----------------------------------
model = LitAutoEncoder()
torch.jit.save(model.to_torchscript(), "model.pt")
os.path.isfile("model.pt")

# ----------------------------------
# onnx
# ----------------------------------
import tempfile

with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmpfile:
    model = LitAutoEncoder()
    input_sample = torch.randn((1, 28 * 28))
    model.to_onnx(tmpfile.name, input_sample, export_params=True)
    os.path.isfile(tmpfile.name)
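The exported TorchScript file can then be loaded back for inference; a minimal sketch (assuming the module defines a ``forward`` method, as in Option 2, so it can be called):

.. code-block:: python

# load the scripted module saved above (assumes the "model.pt" from the torchscript example)
scripted_model = torch.jit.load("model.pt")
scripted_model.eval()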
-----------

***********
Checkpoints
***********
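Lightning checkpoints your model automatically during training (see the Trainer defaults above). A minimal sketch of restoring a model from a checkpoint, assuming an illustrative checkpoint path:

.. code-block:: python

# restore a trained LightningModule from a saved checkpoint (the path is illustrative)
model = LitAutoEncoder.load_from_checkpoint('path/to/checkpoint_file.ckpt')
model.eval()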
@@ -186,81 +257,99 @@ Optional features
TrainResult/EvalResult
======================
Instead of returning the loss, you can also return a :class:`~pytorch_lightning.core.step_result.TrainResult`
(in ``training_step``) or an :class:`~pytorch_lightning.core.step_result.EvalResult` (in ``validation_step``/``test_step``).
These are plain dict-like objects that give you options for logging to Tensorboard or your favorite logger,
and/or to the progress bar, on every step and/or at the end of the epoch. Read more in :ref:`results`.

.. code-block::

def training_step(self, batch, batch_idx):
    ...
    loss = F.mse_loss(x_hat, x)
    result = pl.TrainResult(minimize=loss)
    return result

# equivalent
def training_step(self, batch, batch_idx):
    ...
    loss = F.mse_loss(x_hat, x)
    return loss

To enable logging:

.. code-block::

def training_step(self, batch, batch_idx):
    ...
    loss = F.mse_loss(x_hat, x)
    result = pl.TrainResult(minimize=loss)

    # .log sends to tensorboard/logger, prog_bar=True also sends to the progress bar
    # (note that refreshing the progress bar too frequently in Jupyter notebooks or Colab may freeze your UI)
    result.log('my_train_loss', loss, prog_bar=True)
    return result

And for the validation loop use the :class:`~pytorch_lightning.core.step_result.EvalResult` object.

.. code-block::

def validation_step(self, batch, batch_idx):
    ...
    loss = F.mse_loss(x_hat, x)

    # checkpoint the model based on the validation loss
    result = pl.EvalResult(checkpoint_on=loss)
    result.log('val_loss', loss)
    return result

For certain train/val/test loops, you may wish to do more than just logging. In this case,
you can also implement the ``*_epoch_end`` hooks (for example ``validation_epoch_end``), which give you
the outputs of every step in that loop.

Here's the motivating Pytorch example:

.. code-block:: python

validation_step_outputs = []
for batch_idx, batch in enumerate(val_dataloader()):
    out = validation_step(batch, batch_idx)
    validation_step_outputs.append(out)

validation_epoch_end(validation_step_outputs)

And the lightning equivalent:

.. code-block::

def validation_step(self, batch, batch_idx):
    ...
    loss = F.mse_loss(x_hat, x)
    predictions = ...

    # lightning monitors 'checkpoint_on' to know when to checkpoint (this is a tensor)
    result = pl.EvalResult(checkpoint_on=loss)
    result.log('val_loss', loss)
    result.predictions = predictions
    return result

def validation_epoch_end(self, validation_step_outputs):
    all_val_losses = validation_step_outputs.val_loss
    all_predictions = validation_step_outputs.predictions

.. note:: A Result object is just a dictionary (print it to verify for yourself!)
Callbacks
=========
A callback is an arbitrary self-contained program that can be executed at arbitrary parts of the training loop.
Things you can do with a callback:
- send emails at some point in training
- grow the model
- update learning rates
- visualize gradients
- ...
- you are limited by your imagination
Here's an example adding a not-so-fancy learning rate decay rule:
.. code-block:: python
class DecayLearningRate(pl.Callback):

    def __init__(self):
        self.old_lrs = []

    def on_train_start(self, trainer, pl_module):
        # track the initial learning rates
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            group = []
            for param_group in optimizer.param_groups:
                group.append(param_group['lr'])
            self.old_lrs.append(group)

    def on_train_epoch_end(self, trainer, pl_module):
        for opt_idx, optimizer in enumerate(trainer.optimizers):
            old_lr_group = self.old_lrs[opt_idx]
            new_lr_group = []
            for p_idx, param_group in enumerate(optimizer.param_groups):
                old_lr = old_lr_group[p_idx]
                new_lr = old_lr * 0.98
                new_lr_group.append(new_lr)
                param_group['lr'] = new_lr
            self.old_lrs[opt_idx] = new_lr_group
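To actually use the callback, pass it to the Trainer; a short sketch (``callbacks`` is a standard Trainer argument):

.. code-block:: python

# attach the callback so its hooks run during training (reuses the model and train_loader from Step 2)
trainer = pl.Trainer(callbacks=[DecayLearningRate()])
trainer.fit(model, train_loader)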
Datamodules
===========