<p align="center">
  <a href="https://williamfalcon.github.io/pytorch-lightning/">
    <img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/lightning_logo.png" width="50">
  </a>
</p>

<h3 align="center">
  PyTorch Lightning
</h3>

<p align="center">
  The Keras for ML researchers using PyTorch. More control. Less boilerplate.
</p>

<p align="center">
  <a href="https://badge.fury.io/py/pytorch-lightning"><img src="https://badge.fury.io/py/pytorch-lightning.svg" alt="PyPI version" height="18"></a>
  <!-- <a href="https://travis-ci.org/williamFalcon/pytorch-lightning"><img src="https://travis-ci.org/williamFalcon/pytorch-lightning.svg?branch=master"></a> -->
  <a href="https://github.com/williamFalcon/pytorch-lightning/blob/master/COPYING"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
</p>
```bash
pip install pytorch-lightning
```
## Docs
**[View the docs here](https://williamfalcon.github.io/pytorch-lightning/)**
## What is it?
Keras is too abstract for researchers. Lightning abstracts the full training loop but gives you control at the critical points.
## Why do I want to use Lightning?
Because you want to follow best practices and get GPU training, multi-node training, checkpointing, mixed-precision training, and more for free.
To use Lightning, do two things (a minimal sketch follows the list):
1. [Define a trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/examples/basic_trainer.py) (which will run ALL your models).
2. [Define a model](https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/examples/example_model.py).
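
Putting the two together looks roughly like this (a sketch; `MyModel` is a stand-in for your own Lightning model, defined as in the model example further down):

```python
from pytorch_lightning.models.trainer import Trainer

# MyModel is your own Lightning model (a placeholder name used here for illustration)
model = MyModel(...)

# the trainer runs the full training/validation loop for you
trainer = Trainer()
trainer.fit(model)
```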
## What are some key Lightning features?
- Automatic training loop
```python
# define what happens for training here
def training_step(self, data_batch, batch_nb):
    x, y = data_batch
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'loss': loss}
```
- Automatic validation loop
```python
# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'loss': loss}
```
- Automatic early stopping
```python
callback = EarlyStopping(...)
Trainer(early_stop_callback=callback)
```
- Learning rate annealing
```python
# anneal at 100 and 200 epochs
Trainer(lr_scheduler_milestones=[100, 200])
```
- 16-bit precision training (requires apex to be installed)
```python
Trainer(use_amp=True, amp_level='O2')
```
- Multi-GPU training
```python
# train on 4 gpus
Trainer(gpus=[0, 1, 2, 3])
```
- Automatic checkpointing
```python
# do 3 things:
# 1. pass a checkpoint callback to the trainer
Trainer(checkpoint_callback=ModelCheckpoint(...))

# 2. return what to save in a checkpoint
def get_save_dict(self):
    return {'state_dict': self.state_dict()}

# 3. use the checkpoint to restore your model state
def load_model_specific(self, checkpoint):
    self.load_state_dict(checkpoint['state_dict'])
```
- Log all details of your experiment (model params, code snapshot, etc.)
```python
from test_tube import Experiment
exp = Experiment(...)
Trainer(experiment=exp)
```
- Run grid search on a cluster
```python
from test_tube import Experiment, SlurmCluster, HyperOptArgumentParser

def training_fx(hparams, cluster, _):
    # hparams are local params
    model = MyModel()
    trainer = Trainer(...)
    trainer.fit(model)

# grid search number of layers
parser = HyperOptArgumentParser(strategy='grid_search')
parser.opt_list('--layers', default=5, type=int, options=[1, 5, 10, 20, 50])
hyperparams = parser.parse_args()

cluster = SlurmCluster(hyperparam_optimizer=hyperparams)
cluster.optimize_parallel_cluster_gpu(training_fx)
```
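
Most of these features are just Trainer flags, so they compose. A sketch combining options shown above (the experiment, callbacks, and model are assumed to be configured as in the earlier snippets):

```python
from test_tube import Experiment
from pytorch_lightning.models.trainer import Trainer
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint

# assumed to be configured as in the snippets above
exp = Experiment(...)
early_stop = EarlyStopping(...)
checkpoint = ModelCheckpoint(...)
model = MyModel(...)  # your Lightning model, defined as in the model example below

trainer = Trainer(
    experiment=exp,                      # log experiment details
    checkpoint_callback=checkpoint,      # automatic checkpointing
    early_stop_callback=early_stop,      # automatic early stopping
    gpus=[0, 1],                         # multi-gpu training
    use_amp=True, amp_level='O2',        # 16-bit precision (requires apex)
    lr_scheduler_milestones=[100, 200],  # learning rate annealing
)
trainer.fit(model)
```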
#### Quick demo
Run the following demo to see how it works:
```bash
# install lightning
pip install pytorch-lightning
# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning/docs/source/examples
# run demo (on cpu)
python fully_featured_trainer.py
```
Without changing the model AT ALL, you can run it on a single GPU, across multiple GPUs, or across multiple nodes.
```bash
# run a grid search on two gpus
python fully_featured_trainer.py --gpus "0;1"

# run single model on multiple gpus
python fully_featured_trainer.py --gpus "0;1" --interactive
```
#### Basic trainer example
See [this demo](https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/examples/fully_featured_trainer.py) for a more robust trainer example.
```python
import os
import sys

from test_tube import HyperOptArgumentParser, Experiment
from pytorch_lightning.models.trainer import Trainer
from pytorch_lightning.utils.arg_parse import add_default_args
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint

from demo.example_model import ExampleModel


def main(hparams):
    """
    Main training routine specific for this project
    :param hparams:
    :return:
    """
    # init experiment
    exp = Experiment(
        name=hparams.tt_name,
        debug=hparams.debug,
        save_dir=hparams.tt_save_path,
        version=hparams.hpc_exp_number,
        autosave=False,
        description=hparams.tt_description
    )
    exp.argparse(hparams)
    exp.save()

    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)

    # build model
    model = ExampleModel(hparams)

    # callbacks
    early_stop = EarlyStopping(monitor='val_acc', patience=3, mode='min', verbose=True)
    checkpoint = ModelCheckpoint(
        filepath=model_save_path,
        save_function=None,
        save_best_only=True,
        verbose=True,
        monitor='val_acc',
        mode='min'
    )

    # configure trainer
    trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint, early_stop_callback=early_stop)

    # train model
    trainer.fit(model)


if __name__ == '__main__':
    # use default args given by lightning
    root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0]
    parent_parser = HyperOptArgumentParser(strategy='random_search', add_help=False)
    add_default_args(parent_parser, root_dir)

    # allow model to overwrite or extend args
    parser = ExampleModel.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()

    # train model
    main(hyperparams)
```
#### Basic model example
Here we only show the method signatures. It's up to you to define the content.
```python
import numpy as np
from test_tube import HyperOptArgumentParser
from torch import nn


# RootModule is Lightning's base class for models
class My_Model(RootModule):

    def __init__(self, hparams):
        super(My_Model, self).__init__()
        self.hparams = hparams

        # define model
        self.l1 = nn.Linear(200, 10)

    # ---------------
    # TRAINING
    def training_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat)

        return loss, {'train_loss': loss}

    def validation_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat)

        return loss, {'val_loss': loss}

    def validation_end(self, outputs):
        # collate the outputs of every validation step
        val_losses = []

        for output in outputs:
            val_losses.append(output['val_loss'].item())

        # return a dict
        return {'avg_val_loss': np.mean(val_losses)}

    # ---------------
    # SAVING
    def get_save_dict(self):
        # lightning saves for you. Here's your chance to say what you want to save
        checkpoint = {'state_dict': self.state_dict()}
        return checkpoint

    def load_model_specific(self, checkpoint):
        # lightning loads for you. Here's your chance to say what you want to load
        self.load_state_dict(checkpoint['state_dict'])

    # ---------------
    # TRAINING CONFIG
    def configure_optimizers(self):
        # give lightning the list of optimizers you want to use.
        # lightning will call them automatically
        optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
        return [optimizer]

    @property
    def tng_dataloader(self):
        return pytorch_dataloader('train')

    @property
    def val_dataloader(self):
        return pytorch_dataloader('val')

    @property
    def test_dataloader(self):
        return pytorch_dataloader('test')

    # ---------------
    # MODIFY YOUR COMMAND LINE ARGS
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--out_features', default=20)
        return parser
```
### Details
#### Model definition
| Name | Description | Input | Return |
|---|---|---|---|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
#### Model training
| Name | Description | Input | Return |
|---|---|---|---|
| configure_optimizers | called during training setup | None | list: optimizers you want to use |
| tng_dataloader | called during training | None | pytorch dataloader |
| val_dataloader | called during validation | None | pytorch dataloader |
| test_dataloader | called during testing | None | pytorch dataloader |
| add_model_specific_args | called with args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse |
#### Model Saving/Loading
| Name | Description | Input | Return |
|---|---|---|---|
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
## Optional model hooks
Add these to the model whenever you want to configure training behavior.
### Model lifecycle hooks
Use these hooks to customize functionality.

| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| on_batch_start() | called right before the batch starts | - | - | N |
| on_batch_end() | called right after the batch ends | - | - | N |
| on_epoch_start() | called right before the epoch starts | - | - | N |
| on_epoch_end() | called right after the epoch ends | - | - | N |
| on_pre_performance_check() | called right before the performance check starts | - | - | N |
| on_post_performance_check() | called right after the performance check ends | - | - | N |
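
For example, a model could implement a couple of these hooks directly (a sketch; the hook bodies and the `epoch_losses` attribute are made up here for illustration):

```python
class My_Model(RootModule):
    ...

    def on_epoch_start(self):
        # reset any per-epoch bookkeeping before the epoch begins
        self.epoch_losses = []

    def on_batch_end(self):
        # do lightweight work after every batch, e.g. custom logging
        pass
```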