<p align="center">
<a href="https://williamfalcon.github.io/pytorch-lightning/">
<img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/lightning_logo.png" width="50">
</a>
</p>
<h3 align="center">
PyTorch Lightning
</h3>
<p align="center">
The PyTorch Keras for ML researchers. More control. Less boilerplate.
</p>
<p align="center">
<a href="https://badge.fury.io/py/pytorch-lightning"><img src="https://badge.fury.io/py/pytorch-lightning.svg" alt="PyPI version" height="18"></a>
<a href="https://pepy.tech/project/pytorch-lightning"><img src="https://pepy.tech/badge/pytorch-lightning" alt="PyPI version" height="18"></a>
<a href="https://github.com/williamFalcon/pytorch-lightning/tree/master/tests"><img src="https://github.com/williamFalcon/pytorch-lightning/blob/master/coverage.svg"></a>
<a href="https://travis-ci.org/williamFalcon/pytorch-lightning"><img src="https://travis-ci.org/williamFalcon/pytorch-lightning.svg?branch=master"></a>
<a href="https://williamfalcon.github.io/pytorch-lightning/"><img src="https://readthedocs.org/projects/pytorch-lightning/badge/?version=latest"></a>
<a href="https://github.com/williamFalcon/pytorch-lightning/blob/master/COPYING"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
</p>

```bash
pip install pytorch-lightning
```
## Docs
**[View the docs here](https://williamfalcon.github.io/pytorch-lightning/)**
## What is it?
Lightning defers the training and validation loop logic to you while guaranteeing correct, modern best practices for everything else in the training process.
## Why do I want to use lightning?
When starting a new project, the last thing you want to do is recode a training loop, multi-cluster training, 16-bit precision, early stopping, model loading/saving, when to validate, etc. You're likely to spend a long time ironing out bugs before you even get to the core of your research.

With lightning, those parts of your code are guaranteed to work, so you can focus on the meat of the research: the data and the training/validation loop logic. Don't worry about training on multiple GPUs or speeding up your code; lightning will do that for you!
## How do I use it?
To use lightning, do 2 things:
1. [Define a LightningModule](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/)
```python
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import pytorch_lightning as ptl


class CoolModel(ptl.LightningModule):

    def __init__(self):
        super(CoolModel, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def my_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': self.my_loss(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': self.my_loss(y_hat, y)}

    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    @ptl.data_loader
    def tng_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
```
2. Fit with a [trainer](https://williamfalcon.github.io/pytorch-lightning/Trainer/)
```python
import os

from pytorch_lightning import Trainer
from test_tube import Experiment

model = CoolModel()
exp = Experiment(save_dir=os.getcwd())

# train on cpu using only 10% of the data (for demo purposes)
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)

# train on 4 gpus
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3])

# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4)

# train (1 epoch only here for demo)
trainer.fit(model)

# view tensorboard logs
print(f'View tensorboard logs by running\ntensorboard --logdir {os.getcwd()}')
print('and going to http://localhost:6006 on your browser')
```
## What does lightning control for me?
Everything in gray!
You define the blue parts using the LightningModule interface:
<p align="center">
<a href="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/overview_flat.jpg">
<img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/overview_flat.jpg" height="700px">
</a>
</p>

```python
# what to do in the training loop
def training_step(self, data_batch, batch_nb):
# what to do in the validation loop
def validation_step(self, data_batch, batch_nb):
# how to aggregate validation_step outputs
def validation_end(self, outputs):
# and your dataloaders
def tng_dataloader(self):
def val_dataloader(self):
def test_dataloader(self):
```
**Could be as complex as seq2seq + attention**
```python
# define what happens for training here
def training_step(self, data_batch, batch_nb):
    x, y = data_batch

    # define your own forward and loss calculation
    hidden_states = self.encoder(x)

    # even as complex as a seq2seq + attn model
    # (this is just a toy, non-working example to illustrate)
    start_token = '<SOS>'
    last_hidden = torch.zeros(...)
    loss = 0
    for step in range(max_seq_len):
        attn_context = self.attention_nn(hidden_states, start_token)
        pred = self.decoder(start_token, attn_context, last_hidden)
        last_hidden = pred
        pred = self.predict_nn(pred)
        loss += self.loss(pred, y[step])

    # toy example as well
    loss = loss / max_seq_len
    return {'loss': loss}
```
**Or as basic as CNN image classification**
```python
# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch

    # or as basic as a CNN classification
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'loss': loss}
```
**And you also decide how to collate the output of all validation steps**
```python
def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of individual outputs of each validation step
    :return:
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']

    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
    return tqdm_dic
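
For the aggregation above to work, each `validation_step` must return both `val_loss` and `val_acc`. A minimal sketch of a matching step, reusing `F` from the earlier snippet and assuming `y_hat` are class logits (the accuracy computation here is illustrative, not part of the required interface):

```python
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch
    y_hat = self.forward(x)
    val_loss = F.cross_entropy(y_hat, y)
    # fraction of correct predictions in this batch (illustrative metric)
    val_acc = (y_hat.argmax(dim=1) == y).float().mean()
    return {'val_loss': val_loss, 'val_acc': val_acc}
```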
## Tensorboard
Lightning is fully integrated with tensorboard.
<p align="center">
<a href="https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#tensorboard-support">
<img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/tf_loss.png" width="900px">
</a>
</p>

Lightning also adds a text column with all the hyperparameters for this experiment.
<p align="center">
<a href="https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#tensorboard-support">
<img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/tf_tags.png" width="900px">
</a>
</p>

Simply note the path you set for the Experiment:
```python
from test_tube import Experiment
from pytorch_lightning import Trainer

exp = Experiment(save_dir='/some/path')
trainer = Trainer(experiment=exp)
...
```
And run tensorboard from that dir:
```bash
tensorboard --logdir /some/path
```
## Lightning automates all of the following ([each is also configurable](https://williamfalcon.github.io/pytorch-lightning/Trainer/)):
###### Checkpointing
- [Model saving](https://williamfalcon.github.io/pytorch-lightning/Trainer/Checkpointing/#model-saving)
- [Model loading](https://williamfalcon.github.io/pytorch-lightning/LightningModule/methods/#load-from-metrics)
###### Computing cluster (SLURM)
- [Running grid search on a cluster](https://williamfalcon.github.io/pytorch-lightning/Trainer/SLURM%20Managed%20Cluster#running-grid-search-on-a-cluster)
- [Walltime auto-resubmit](https://williamfalcon.github.io/pytorch-lightning/Trainer/SLURM%20Managed%20Cluster#walltime-auto-resubmit)
###### Debugging
- [Fast dev run](https://williamfalcon.github.io/pytorch-lightning/Trainer/debugging/#fast-dev-run)
- [Inspect gradient norms](https://williamfalcon.github.io/pytorch-lightning/Trainer/debugging/#inspect-gradient-norms)
- [Log GPU usage](https://williamfalcon.github.io/pytorch-lightning/Trainer/debugging/#Log-gpu-usage)
- [Make model overfit on subset of data](https://williamfalcon.github.io/pytorch-lightning/Trainer/debugging/#make-model-overfit-on-subset-of-data)
- [Print the parameter count by layer](https://williamfalcon.github.io/pytorch-lightning/Trainer/debugging/#print-the-parameter-count-by-layer)
- [Print which gradients are nan](https://williamfalcon.github.io/pytorch-lightning/Trainer/debugging/#print-which-gradients-are-nan)
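
A sketch of how a few of these toggle on the Trainer; the flag names follow the docs linked above, but treat exact spellings as assumptions for your installed version:

```python
# run 1 training batch and 1 validation batch to smoke-test the whole loop
trainer = Trainer(experiment=exp, fast_dev_run=True)

# try to overfit on a small subset of the data as a sanity check
trainer = Trainer(experiment=exp, overfit_pct=0.01)

# log the 2-norm of the gradients and print any that go to nan
trainer = Trainer(experiment=exp, track_grad_norm=2, print_nan_grads=True)
```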
###### Distributed training
- [16-bit mixed precision](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#16-bit-mixed-precision)
- [Multi-GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-GPU)
- [Multi-node](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#Multi-node)
- [Single GPU](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#single-gpu)
- [Self-balancing architecture](https://williamfalcon.github.io/pytorch-lightning/Trainer/Distributed%20training/#self-balancing-architecture)
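
Scaling is a matter of changing Trainer arguments, not model code. A sketch using the same flags as the example above (`use_amp` assumes NVIDIA apex is installed, and its name is an assumption for your version):

```python
# single gpu
trainer = Trainer(experiment=exp, gpus=[0])

# 4 gpus with 16-bit mixed precision
trainer = Trainer(experiment=exp, gpus=[0, 1, 2, 3], use_amp=True)

# 32 gpus: 8 per node across 4 SLURM-managed nodes
trainer = Trainer(experiment=exp, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4)
```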
###### Experiment Logging
- [Display metrics in progress bar](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#display-metrics-in-progress-bar)
- [Log metric row every k batches](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#log-metric-row-every-k-batches)
- [Process position](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#process-position)
- [Tensorboard support](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#tensorboard-support)
- [Save a snapshot of all hyperparameters](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#save-a-snapshot-of-all-hyperparameters)
- [Snapshot code for a training run](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#snapshot-code-for-a-training-run)
- [Write logs file to csv every k batches](https://williamfalcon.github.io/pytorch-lightning/Trainer/Logging/#write-logs-file-to-csv-every-k-batches)
###### Training loop
- [Accumulate gradients](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#accumulated-gradients)
- [Force training for min or max epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-training-for-min-or-max-epochs)
- [Force disable early stop](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#force-disable-early-stop)
- [Gradient Clipping](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#gradient-clipping)
- [Hooks](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/)
- [Learning rate scheduling](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers)
- [Use multiple optimizers (like GANs)](https://williamfalcon.github.io/pytorch-lightning/LightningModule/RequiredTrainerInterface/#configure_optimizers)
- [Set how much of the training set to check (1-100%)](https://williamfalcon.github.io/pytorch-lightning/Trainer/Training%20Loop/#set-how-much-of-the-training-set-to-check)
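
A sketch combining a few of these knobs (names per the docs linked above; exact spellings may differ between versions):

```python
trainer = Trainer(
    experiment=exp,
    accumulate_grad_batches=4,   # effective batch size = 4 * batch_size
    gradient_clip=0.5,           # clip the gradient norm at 0.5
    min_nb_epochs=1,             # always train for at least 1 epoch
    max_nb_epochs=100,           # never train past 100 epochs
    train_percent_check=0.25,    # use only 25% of the training set per epoch
)
```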
###### Validation loop
- [Check validation every n epochs](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#check-validation-every-n-epochs)
- [Hooks](https://williamfalcon.github.io/pytorch-lightning/Trainer/hooks/)
- [Set how much of the validation set to check](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-how-much-of-the-validation-set-to-check)
- [Set how much of the test set to check](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-how-much-of-the-test-set-to-check)
- [Set validation check frequency within 1 training epoch](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-validation-check-frequency-within-1-training-epoch)
- [Set the number of validation sanity steps](https://williamfalcon.github.io/pytorch-lightning/Trainer/Validation%20loop/#set-the-number-of-validation-sanity-steps)
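
And the validation-loop equivalents, again as a hedged sketch (flag names per the docs linked above):

```python
trainer = Trainer(
    experiment=exp,
    check_val_every_n_epoch=2,   # run the validation loop every 2nd epoch
    val_percent_check=0.5,       # check only 50% of the validation set
    test_percent_check=0.5,      # check only 50% of the test set
    val_check_interval=0.25,     # validate 4 times within each training epoch
    nb_sanity_val_steps=5,       # run 5 val batches before training starts
)
```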
## Demo
```bash
# install lightning
pip install pytorch-lightning

# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning/examples/new_project_templates/

# all of the following demos use the SAME model to show no modification needs to be made to your code

# train on cpu
python single_cpu_template.py

# train on multiple gpus
python single_gpu_node_template.py --gpus "0,1"

# train on 32 gpus across 4 nodes (run on a SLURM-managed cluster)
python multi_node_cluster_template.py --nb_gpu_nodes 4 --gpus '0,1,2,3,4,5,6,7'
```
## Contributing
Welcome to the PTL community! We're building the most advanced research platform on the planet, implementing the latest best practices that the amazing PyTorch team rolls out!
#### Bug fixes:
1. Submit a GitHub issue.
2. Fix it.
3. Submit a PR!

#### New Features:
1. Submit a GitHub issue.
2. We'll agree on the feature scope.
3. Submit a PR! (with updated docs and tests 🙃).
## Bleeding edge
If you can't wait for the next release, install the most up-to-date code with:
```bash
pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
```