PyTorch Lightning

The Keras for ML researchers using PyTorch. More control. Less boilerplate.

pip install pytorch-lightning  

Docs

View the docs here

What is it?

Lightning defers training and validation loop logic to you. It guarantees correct, modern best practices for the core training logic.

Why do I want to use lightning?

When starting a new project the last thing you want to do is recode a training loop, model loading/saving, distributed training, when to validate, etc... You're likely to spend a long time ironing out all the bugs without even getting to the core of your research.

With Lightning, those parts of your code are guaranteed to work, so you can focus on the meat of the research: the data and the training and validation loop logic. Don't worry about multiple GPUs or speeding up your code; Lightning will do that for you!

How do I use it?

To use Lightning, do two things:

  1. Define a LightningModule
import pytorch_lightning as ptl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

class CoolModel(ptl.LightningModule):

    def __init__(self):
        super(CoolModel, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten each 1x28x28 image into a 784-dim vector before the linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def my_loss(self, y_hat, y):
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'tng_loss': self.my_loss(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': self.my_loss(y_hat, y)}

    def validation_end(self, outputs):
        # outputs is a list with one dict per validation step
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    # ToTensor converts the PIL images into tensors so they can be batched
    @ptl.data_loader
    def tng_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def val_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False, download=True, transform=transforms.ToTensor()), batch_size=32)

    @ptl.data_loader
    def test_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
  2. Fit with a trainer
from pytorch_lightning import Trainer
from test_tube import Experiment

model = CoolModel()

# fit on 32 gpus across 4 nodes
exp = Experiment(save_dir='some/dir')
trainer = Trainer(experiment=exp, nb_gpu_nodes=4, gpus=[0,1,2,3,4,5,6,7])

trainer.fit(model)

# see all experiment metrics here
# tensorboard --logdir some/dir

What does lightning control for me?

Everything!
Except for these 6 core functions which you define:

# what to do in the training loop
def training_step(self, data_batch, batch_nb): ...

# what to do in the validation loop
def validation_step(self, data_batch, batch_nb): ...

# how to aggregate validation_step outputs
def validation_end(self, outputs): ...

# and your dataloaders
def tng_dataloader(self): ...
def val_dataloader(self): ...
def test_dataloader(self): ...
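
To make the division of labour concrete, here is a minimal sketch of the order in which a trainer might call these functions. It is not Lightning's actual Trainer: `ToyTrainer` and `ToyModel` are hypothetical names invented for this illustration, the "dataloaders" are plain lists, and the losses are plain floats so the example runs without PyTorch. `test_dataloader` is left out to keep the sketch short.

```python
class ToyTrainer:
    """Illustrative stand-in for a trainer; NOT Lightning's real Trainer."""

    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs

    def fit(self, model):
        results = None
        for _ in range(self.max_epochs):
            # training loop: one training_step call per batch
            for batch_nb, batch in enumerate(model.tng_dataloader()):
                model.training_step(batch, batch_nb)
                # a real trainer would run backward() and the optimizer step here

            # validation loop: collect the output of every validation_step...
            outputs = [
                model.validation_step(batch, batch_nb)
                for batch_nb, batch in enumerate(model.val_dataloader())
            ]
            # ...and hand the whole list to validation_end
            results = model.validation_end(outputs)
        return results


class ToyModel:
    """Hypothetical model that uses floats instead of tensors."""

    def tng_dataloader(self):
        return [(1.0, 2.0), (2.0, 4.0)]

    def val_dataloader(self):
        return [(3.0, 6.0), (4.0, 9.0)]

    def training_step(self, batch, batch_nb):
        x, y = batch
        return {'tng_loss': abs(2 * x - y)}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        return {'val_loss': abs(2 * x - y)}

    def validation_end(self, outputs):
        return {'val_loss': sum(o['val_loss'] for o in outputs) / len(outputs)}


results = ToyTrainer(max_epochs=1).fit(ToyModel())
print(results)
```

The model only describes what happens for one batch and how to summarise a validation run; everything about when those functions are called lives in the trainer.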

Could be as complex as seq-2-seq + attention

# define what happens for training here
def training_step(self, data_batch, batch_nb):
    x, y = data_batch
    
    # define your own forward and loss calculation
    hidden_states = self.encoder(x)
     
    # even as complex as a seq-2-seq + attn model
    # (this is just a toy, non-working example to illustrate)
    start_token = '<SOS>'
    last_hidden = torch.zeros(...)
    loss = 0
    for step in range(max_seq_len):
        attn_context = self.attention_nn(hidden_states, start_token)
        pred = self.decoder(start_token, attn_context, last_hidden)
        last_hidden = pred
        pred = self.predict_nn(pred)
        # score the prediction (not the raw hidden state) against the target
        loss += self.loss(pred, y[step])

    # toy example as well
    loss = loss / max_seq_len
    return {'loss': loss} 

Or as basic as CNN image classification

# define what happens for validation here
def validation_step(self, data_batch, batch_nb):    
    x, y = data_batch
    
    # or as basic as a CNN classification
    out = self.forward(x)
    loss = self.my_loss(out, y)
    return {'val_loss': loss}

And you also decide how to collate the output of all validation steps

def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of the dicts returned by each validation step
    :return: dict of averaged metrics
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']

    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
    return tqdm_dic
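
The same aggregation can be written generically so it averages whatever keys the validation steps return. This is a small sketch, not part of Lightning: `average_metrics` is a hypothetical helper, and plain floats stand in for the loss and accuracy tensors so it runs without PyTorch.

```python
def average_metrics(outputs):
    """Average every key across a list of per-step output dicts."""
    if not outputs:
        return {}
    return {
        key: sum(output[key] for output in outputs) / len(outputs)
        for key in outputs[0]
    }


# one dict per validation step (floats stand in for tensors here)
step_outputs = [
    {'val_loss': 0.5, 'val_acc': 0.75},
    {'val_loss': 1.5, 'val_acc': 0.25},
]
averaged = average_metrics(step_outputs)
print(averaged)
```

With real tensors the averaged values would themselves be tensors, so you would still call `.item()` on them before displaying them, as the example above does.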

Tensorboard

Lightning is fully integrated with tensorboard.

Lightning also adds a text column with all the hyperparameters for this experiment.

Simply note the path you set for the Experiment

from test_tube import Experiment
from pytorch_lightning import Trainer

exp = Experiment(save_dir='/some/path')
trainer = Trainer(experiment=exp)
...

And run tensorboard from that dir

tensorboard --logdir /some/path     

Lightning automates all of the following (each is also configurable):

Checkpointing
Computing cluster (SLURM)
Debugging
Distributed training
Experiment Logging
Training loop
Validation loop

Demo

# install lightning
pip install pytorch-lightning

# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning/examples/new_project_templates/

# all of the following demos use the SAME model to show no modification needs to be made to your code

# train on cpu 
python single_cpu_template.py

# train on multiple-gpus 
python single_gpu_node_template.py --gpus "0,1"

# train on 32 gpus on a cluster (run on a SLURM managed cluster)
python multi_node_cluster_template.py --nb_gpu_nodes 4 --gpus '0,1,2,3,4,5,6,7'
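
# The flags above are strings on the command line, while the Trainer example
# earlier takes gpus as a list of ints. One way a script could bridge the two is
# sketched below. This is an assumption for illustration, not the templates'
# actual code; parse_gpu_ids is a hypothetical helper.

```python
import argparse


def parse_gpu_ids(value):
    """Turn a string such as '0,1,2' into a list of ints: [0, 1, 2]."""
    return [int(gpu_id) for gpu_id in value.split(',') if gpu_id.strip()]


parser = argparse.ArgumentParser()
parser.add_argument('--gpus', type=parse_gpu_ids, default=[])
parser.add_argument('--nb_gpu_nodes', type=int, default=1)

# same flags as the multi-node demo command
args = parser.parse_args(['--nb_gpu_nodes', '4', '--gpus', '0,1,2,3,4,5,6,7'])

# 4 nodes x 8 GPUs per node = 32 GPUs in total
total_gpus = args.nb_gpu_nodes * len(args.gpus)
print(args.gpus, total_gpus)
```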

Bleeding edge

If you can't wait for the next release, install the most up-to-date code with:

pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade