<p align="center">
<a href="https://williamfalcon.github.io/pytorch-lightning/">
<img alt="" src="https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/lightning_logo.png" width="50">
</a>
</p>
<h3 align="center">
PyTorch Lightning
</h3>
<p align="center">
The Keras for ML researchers using PyTorch. More control. Less boilerplate.
</p>
<p align="center">
<a href="https://badge.fury.io/py/pytorch-lightning"><img src="https://badge.fury.io/py/pytorch-lightning.svg" alt="PyPI version" height="18"></a>
<!-- <a href="https://travis-ci.org/williamFalcon/test-tube"><img src="https://travis-ci.org/williamFalcon/pytorch-lightning.svg?branch=master"></a> -->
<a href="https://github.com/williamFalcon/pytorch-lightning/blob/master/COPYING"><img src="https://img.shields.io/badge/License-MIT-yellow.svg"></a>
</p>
```bash
pip install pytorch-lightning
```
## Docs
**[View the docs here](https://williamfalcon.github.io/pytorch-lightning/)**
## What is it?
Keras and fast.ai are too abstract for researchers. Lightning abstracts away the full training loop but gives you control at the critical points.
## Why do I want to use lightning?
Because you want best practices such as GPU training, multi-node training, checkpointing, and mixed-precision training for free, while keeping granular control over the core of the training, validation, and test loops.
To use Lightning, do two things:
1. [Define a trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/examples/basic_trainer.py) (which will run ALL your models).
2. [Define a model](https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/examples/example_model.py).
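To make the trainer/model split above concrete, here is a framework-free toy in plain Python (no PyTorch): the model only says what happens on one batch, and the trainer owns the epoch and batch loops plus the optimizer step. `ToyModel`, `ToyTrainer`, and the extra `'grad'` key are hypothetical stand-ins, not Lightning's classes; the toy hands back its gradient by hand only because there is no autograd here, whereas in Lightning PyTorch computes gradients for you.

```python
class ToyModel:
    """Fits y = w * x by minimising squared error (hypothetical toy model)."""

    def __init__(self):
        self.w = 0.0  # the single learnable parameter

    def forward(self, x):
        return self.w * x

    def training_step(self, data_batch, batch_nb):
        x, y = data_batch
        out = self.forward(x)
        loss = (out - y) ** 2
        # no autograd in this toy, so pass d(loss)/dw to the trainer explicitly
        return {'loss': loss, 'grad': 2 * (out - y) * x}

    def validation_step(self, data_batch, batch_nb):
        x, y = data_batch
        return {'loss': (self.forward(x) - y) ** 2}


class ToyTrainer:
    """Owns the loops and the plain-SGD update (hypothetical toy trainer)."""

    def __init__(self, max_epochs=20, lr=0.05):
        self.max_epochs = max_epochs
        self.lr = lr

    def fit(self, model, train_data, val_data):
        val_history = []
        for epoch in range(self.max_epochs):
            # training loop: the model defines the step, the trainer applies it
            for batch_nb, batch in enumerate(train_data):
                out = model.training_step(batch, batch_nb)
                model.w -= self.lr * out['grad']
            # validation loop: average the per-batch validation losses
            val_losses = [model.validation_step(batch, batch_nb)['loss']
                          for batch_nb, batch in enumerate(val_data)]
            val_history.append(sum(val_losses) / len(val_losses))
        return val_history


model = ToyModel()
train_data = [(x, 3.0 * x) for x in (0.5, 1.0, 1.5, 2.0)]
val_data = [(x, 3.0 * x) for x in (0.75, 1.25)]
history = ToyTrainer().fit(model, train_data, val_data)
print(f"w = {model.w:.4f}, first val loss = {history[0]:.4f}, last = {history[-1]:.2e}")
```

The same `ToyTrainer` could fit any object exposing these hooks, which is the point of keeping the loop out of the model.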
## What are some key lightning features?
- Automatic training loop
```python
# define what happens for training here
# (a method on your model; my_loss stands for your loss function)
def training_step(self, data_batch, batch_nb):
    x, y = data_batch
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'loss': loss}
```
- Automatic validation loop
```python
# define what happens for validation here
# (a method on your model; my_loss stands for your loss function)
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'loss': loss}
```
- Automatic early stopping
```python
callback = EarlyStopping(...)
Trainer(early_stopping=callback)
```
- Learning rate annealing
```python
# anneal at 100 and 200 epochs
Trainer(lr_scheduler_milestones=[100, 200])
```
- 16-bit precision training (requires NVIDIA apex to be installed)
```python
Trainer(use_amp=True, amp_level='O2')
```
- Multi-GPU training
```python
# train on 4 gpus
Trainer(gpus=[0, 1, 2, 3])
```
- Automatic checkpointing
```python
# do 3 things:
# 1. give the trainer a checkpoint callback
Trainer(checkpoint_callback=ModelCheckpoint(...))

# 2. in your model, return what to save in a checkpoint
def get_save_dict(self):
    return {'state_dict': self.state_dict()}

# 3. in your model, use the checkpoint to restore your model state
def load_model_specific(self, checkpoint):
    self.load_state_dict(checkpoint['state_dict'])
```
- Log all details of your experiment (model params, code snapshot, etc.)
```python
from test_tube import Experiment

exp = Experiment(...)
Trainer(experiment=exp)
```
- Run a grid search on a cluster
```python
from test_tube import Experiment, SlurmCluster, HyperOptArgumentParser

def training_fx(hparams, cluster, _):
    # hparams holds the hyperparameters for this particular trial
    model = MyModel()
    trainer = Trainer(...)
    trainer.fit(model)

# grid search over the number of layers
parser = HyperOptArgumentParser(strategy='grid_search')
parser.opt_list('--layers', default=5, type=int, options=[1, 5, 10, 20, 50])
hyperparams = parser.parse_args()

cluster = SlurmCluster(hyperparam_optimizer=hyperparams)
cluster.optimize_parallel_cluster_gpu(training_fx)
```
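The early-stopping and learning-rate annealing bullets hide a little bookkeeping. The plain-Python sketch below shows one common form of each: patience-based stopping on a validation loss, and step decay at fixed milestones. `PatienceEarlyStopping` and `milestone_lr` are hypothetical names; the decay factor `gamma=0.1` is an assumption borrowed from the default of PyTorch's `MultiStepLR`, and Lightning's own `EarlyStopping` callback may monitor a different metric or use different defaults.

```python
from bisect import bisect_right


class PatienceEarlyStopping:
    """Hypothetical early-stopping helper.

    Signals a stop once the monitored loss has failed to improve by at least
    `min_delta` for `patience` consecutive checks.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: remember it and reset the counter
            self.wait = 0
        else:
            self.wait += 1  # no improvement this check
        return self.wait >= self.patience


def milestone_lr(base_lr, epoch, milestones, gamma=0.1):
    """Learning rate after multiplying by `gamma` once per milestone already reached."""
    return base_lr * gamma ** bisect_right(milestones, epoch)


# early stopping: the loss stops improving after the third check
stopper = PatienceEarlyStopping(patience=2)
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.69]
stopped_at = next(i for i, loss in enumerate(losses) if stopper.should_stop(loss))

# annealing at epochs 100 and 200, matching the milestones in the Trainer example above
lrs = [milestone_lr(0.1, e, [100, 200]) for e in (0, 99, 100, 199, 200, 250)]
print(stopped_at, lrs)
```

Here training would halt at check index 4, and the learning rate drops by a factor of 10 at epoch 100 and again at epoch 200.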
## Demo
```bash
# install lightning
pip install pytorch-lightning
# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning/docs/source/examples
# run demo (on cpu)
python fully_featured_trainer.py
```
Without changing the model AT ALL, you can run it on a single GPU, on multiple GPUs, or across multiple nodes.
```bash
# run a grid search on two GPUs
python fully_featured_trainer.py --gpus "0;1"

# run a single model on multiple GPUs
python fully_featured_trainer.py --gpus "0;1" --interactive
```
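In these commands `--gpus` takes a semicolon-separated list of GPU ids. One way to picture the grid-search mode is to parse that string and then deal the grid's trials out to the listed GPUs in turn. `parse_gpu_ids` and `assign_trials` are hypothetical helpers, and round-robin scheduling is an assumption made for illustration; the demo script and test_tube may schedule trials differently.

```python
def parse_gpu_ids(gpus):
    """Turn a semicolon-separated string such as "0;1" into a list of ints."""
    return [int(g) for g in gpus.split(';') if g.strip()]


def assign_trials(trials, gpu_ids):
    """Deal trials out round-robin: trial i goes to gpu_ids[i % len(gpu_ids)]."""
    schedule = {gpu: [] for gpu in gpu_ids}
    for i, trial in enumerate(trials):
        schedule[gpu_ids[i % len(gpu_ids)]].append(trial)
    return schedule


gpu_ids = parse_gpu_ids("0;1")
# one trial per --layers value from the grid-search example
schedule = assign_trials([1, 5, 10, 20, 50], gpu_ids)
print(gpu_ids, schedule)
```

With two GPUs, the five `--layers` trials split three and two; in the `--interactive` mode, by contrast, both GPUs would serve a single model.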