Build and train PyTorch models and connect them to the ML lifecycle using Lightning App templates, without handling DIY infrastructure, cost management, scaling, and other headaches.
Ananya Harsh Jha c0f3b6b035 added set_epoch for distributed sampler, fix for #224 (#225) 2019-09-16 10:21:00 -04:00
.github add PR template (#204) 2019-09-06 10:12:06 -04:00
docs Simplified gpu api. No NVIDIA flag managing by lightning for cluster (#213) 2019-09-08 15:36:58 -04:00
examples Update README.md 2019-09-14 09:55:42 -04:00
pytorch_lightning added set_epoch for distributed sampler, fix for #224 (#225) 2019-09-16 10:21:00 -04:00
tests added load on CPU first (#221) 2019-09-11 07:52:36 -04:00
.codecov.yml add Codecov info (#144) 2019-08-19 06:35:09 -04:00
.gitignore ommit templates folder 2019-08-14 08:59:05 -04:00
.readthedocs.yml pkg relative imports 2019-08-05 10:52:09 +02:00
.travis.yml add osx to Travis (#202) 2019-09-05 15:08:19 -04:00
LICENSE MIT -> apache 2 license 2019-08-06 22:45:45 +02:00
MANIFEST.in Updated distributed Demos (#215) 2019-09-08 18:17:33 -04:00
README.md Update README.md 2019-09-14 02:18:33 -04:00
appveyor.yml update Win CI req. (#123) 2019-08-15 11:45:03 -04:00
mkdocs.yml docs: add repo_name in the upright corner (#171) 2019-08-27 16:46:18 -04:00
pyproject.toml Fix pip install too 2019-04-03 22:47:55 +05:30
requirements.txt fix appveyor - install pytorch 2019-08-07 11:07:22 +02:00
setup.cfg added gan template (#115) 2019-08-14 08:38:49 -04:00
setup.py release v0.4.8 2019-09-02 07:15:45 -04:00
tox.ini enable single gpu per node (#218) 2019-09-09 07:37:20 -04:00
update.sh add CircleCI 2019-08-06 22:45:46 +02:00

README.md


PyTorch Lightning

The lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.


Simple installation from PyPI

pip install pytorch-lightning  

Docs

View the docs here

What is it?

Lightning is a very lightweight wrapper on PyTorch. This means you don't have to learn a new library. To use Lightning, simply refactor your research code into the LightningModule format and Lightning will automate the rest. Lightning guarantees tested, correct, modern best practices for the automated parts.

Starting a new project?

Use our seed-project aimed at reproducibility!

Why do I want to use lightning?

Every research project starts the same way: a model, a training loop, a validation loop, etc. As your research advances, you're likely to need distributed training, 16-bit precision, checkpointing, gradient accumulation, etc.

Lightning sets up all of this state-of-the-art training boilerplate for you so you can focus on the research.




How do I use it?

Think about Lightning as refactoring your research code instead of using a new framework. The research code goes into a LightningModule which you fit using a Trainer.

The LightningModule defines a system such as a seq-2-seq model, a GAN, etc. It can ALSO define a simple classifier, such as the example below.

To use Lightning, do two things:

  1. Define a LightningModule
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=0.02)

    @pl.data_loader
    def tng_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # (demo only: this reuses the training split; validate on held-out data in practice)
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        # train=False loads the 10,000-image MNIST test split
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
  2. Fit with a Trainer
from pytorch_lightning import Trainer

model = CoolSystem()

# most basic trainer, uses good defaults
trainer = Trainer()    
trainer.fit(model)   

Or with a TensorBoard logger and some options turned on, such as multi-GPU training:

from test_tube import Experiment    

# PyTorch SummaryWriter with a few bells and whistles
exp = Experiment(save_dir=os.getcwd())

# train on cpu using only 10% of the data (for demo purposes)
# pass in experiment for automatic tensorboard logging.    
trainer = Trainer(experiment=exp, max_nb_epochs=1, train_percent_check=0.1)

# train on 4 gpus
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3])

# train on 32 gpus across 4 nodes (make sure to submit appropriate SLURM job)
# trainer = Trainer(experiment=exp, max_nb_epochs=1, gpus=[0, 1, 2, 3, 4, 5, 6, 7], nb_gpu_nodes=4)

# train (1 epoch only here for demo)
trainer.fit(model)

# view tensorboard logs
print('View tensorboard logs by running\ntensorboard --logdir %s' % os.getcwd())
print('and going to http://localhost:6006 on your browser')
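
The options above boil down to simple arithmetic. The sketch below uses hypothetical helpers (not part of Lightning) to show how many batches train_percent_check=0.1 leaves of MNIST's 60,000 training images at batch_size=32, assuming the count is truncated with int(), and how 8 GPUs on each of 4 nodes add up to 32 processes, assuming one process per GPU as in distributed data parallel training.

```python
import math

# Hypothetical helpers (not Lightning's API) illustrating the arithmetic
# behind train_percent_check, gpus and nb_gpu_nodes. The int() truncation
# is an assumption and may differ from Lightning's internals.


def nb_training_batches(nb_samples, batch_size, train_percent_check=1.0):
    # a DataLoader without drop_last yields ceil(N / batch_size) batches
    full = math.ceil(nb_samples / batch_size)
    # keep only the requested fraction of those batches
    return int(full * train_percent_check)


def world_size(gpus_per_node, nb_gpu_nodes):
    # total number of training processes, assuming one process per GPU
    return gpus_per_node * nb_gpu_nodes


def global_rank(node_rank, local_rank, gpus_per_node):
    # unique index of one process across all nodes
    return node_rank * gpus_per_node + local_rank


print(nb_training_batches(60000, 32))        # 1875
print(nb_training_batches(60000, 32, 0.1))   # 187
print(world_size(8, 4))                      # 32
print(global_rank(node_rank=3, local_rank=7, gpus_per_node=8))  # 31
```

So the demo run sees roughly a tenth of the 1,875 batches per epoch, and the last GPU on the fourth node is process 31 of 32.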

When you're all done you can even run the test set separately.

trainer.test()

What does Lightning control for me?

Everything in gray!
You define the blue parts using the LightningModule interface:

[Figure: Overview]

# what to do in the training loop
def training_step(self, data_batch, batch_nb): ...

# what to do in the validation loop
def validation_step(self, data_batch, batch_nb): ...

# how to aggregate validation_step outputs
def validation_end(self, outputs): ...

# which optimizers (and learning rate schedulers) to use
def configure_optimizers(self): ...

# and your dataloaders
def tng_dataloader(self): ...
def val_dataloader(self): ...
def test_dataloader(self): ...
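
To make the division of labor concrete, here is a toy, framework-free sketch of the kind of loop a trainer runs around these hooks. ToyTrainer and ToySystem are hypothetical names, the "model" uses plain floats instead of tensors, and this is not Lightning's actual implementation; a real Trainer also performs backpropagation, optimizer steps, device placement, checkpointing and logging.

```python
# Toy sketch of the hook-calling loop. ToyTrainer and ToySystem are
# hypothetical; this is NOT Lightning's code.


class ToySystem:
    """Stand-in for a LightningModule: it only defines the hooks."""

    def training_step(self, batch, batch_nb):
        x, y = batch
        return {'loss': (x - y) ** 2}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        return {'val_loss': (x - y) ** 2}

    def validation_end(self, outputs):
        # aggregate the per-batch outputs into a single metric
        return {'avg_val_loss': sum(o['val_loss'] for o in outputs) / len(outputs)}

    def tng_dataloader(self):
        return [(1.0, 0.0), (2.0, 1.5)]

    def val_dataloader(self):
        return [(0.5, 0.0), (1.0, 1.0)]


class ToyTrainer:
    """Calls the hooks in order and records what came back."""

    def __init__(self, max_nb_epochs=1):
        self.max_nb_epochs = max_nb_epochs
        self.history = []

    def fit(self, model):
        for epoch in range(self.max_nb_epochs):
            # training loop: one training_step call per batch
            for batch_nb, batch in enumerate(model.tng_dataloader()):
                out = model.training_step(batch, batch_nb)
                # a real trainer would backpropagate out['loss'] and step the optimizer here
                self.history.append(('train', epoch, out['loss']))
            # validation loop: collect every validation_step output...
            outputs = [model.validation_step(batch, batch_nb)
                       for batch_nb, batch in enumerate(model.val_dataloader())]
            # ...then hand the whole list to validation_end
            metrics = model.validation_end(outputs)
            self.history.append(('val', epoch, metrics['avg_val_loss']))
        return self.history


print(ToyTrainer(max_nb_epochs=1).fit(ToySystem()))
```

Running it returns `[('train', 0, 1.0), ('train', 0, 0.25), ('val', 0, 0.125)]`: every training batch goes through training_step, while validation_end sees the full list of validation_step outputs at once.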

Your training_step could be as complex as seq-2-seq + attention:

# define what happens for training here
def training_step(self, data_batch, batch_nb):
    x, y = data_batch

    # define your own forward and loss calculation
    hidden_states = self.encoder(x)

    # even as complex as a seq-2-seq + attn model
    # (this is just a toy, non-working example to illustrate)
    start_token = '<SOS>'
    last_hidden = torch.zeros(...)
    loss = 0
    for step in range(max_seq_len):
        attn_context = self.attention_nn(hidden_states, start_token)
        pred = self.decoder(start_token, attn_context, last_hidden)
        last_hidden = pred
        pred = self.predict_nn(pred)
        # score the prediction (not the hidden state) against the target token
        loss += self.loss(pred, y[step])

    # average the loss over the sequence length (toy example as well)
    loss = loss / max_seq_len
    return {'loss': loss}

Or as basic as CNN image classification:

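# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch

    # or as basic as a CNN classification
    out = self.forward(x)
    loss = F.cross_entropy(out, y)
    # fraction of the batch whose highest-scoring class matches the label
    acc = (out.argmax(dim=1) == y).float().mean()
    # these are the keys validation_end below reads
    return {'val_loss': loss, 'val_acc': acc}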
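
The validation_end example below reads a 'val_acc' entry from each step's output. With tensors, that accuracy is typically computed as `(out.argmax(dim=1) == y).float().mean()`; the plain-Python sketch here (a hypothetical helper operating on lists instead of tensors) spells out the same computation.

```python
def accuracy(logits, targets):
    """Fraction of examples whose highest-scoring class equals the label."""
    correct = 0
    for scores, target in zip(logits, targets):
        # argmax: index of the largest score for this example
        pred = max(range(len(scores)), key=scores.__getitem__)
        correct += int(pred == target)
    return correct / len(targets)


# the predicted classes are 1, 0 and 1; the labels are 1, 0 and 0
print(accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))  # 2 of 3 correct
```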

You also decide how to collate the outputs of all validation steps:

def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of individual outputs of each validation step
    :return: dict of averaged validation metrics
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']

    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
    return tqdm_dic
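
The loop above weights every validation step equally, so if the last batch is smaller the result is a mean of batch means rather than a true per-sample mean. The sketch below, using plain floats and hypothetical helper names, shows both the unweighted average and a version weighted by batch size.

```python
def mean_over_steps(outputs):
    """Average each metric across validation steps, one vote per step."""
    keys = outputs[0].keys()
    return {k: sum(o[k] for o in outputs) / len(outputs) for k in keys}


def weighted_mean_over_steps(outputs, batch_sizes):
    """Average each metric across steps, weighting step i by batch_sizes[i]."""
    total = sum(batch_sizes)
    keys = outputs[0].keys()
    return {k: sum(o[k] * n for o, n in zip(outputs, batch_sizes)) / total
            for k in keys}


outputs = [{'val_loss': 0.5, 'val_acc': 0.75},
           {'val_loss': 0.25, 'val_acc': 1.0}]
print(mean_over_steps(outputs))  # {'val_loss': 0.375, 'val_acc': 0.875}
print(weighted_mean_over_steps(outputs, [32, 16]))
```

With batch sizes 32 and 16, the weighted loss is (0.5 * 32 + 0.25 * 16) / 48 ≈ 0.4167, versus 0.375 unweighted.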

Tensorboard

Lightning is fully integrated with TensorBoard.


Lightning also adds a text column with all the hyperparameters for this experiment.


Simply note the path you set for the test_tube Experiment:

from test_tube import Experiment
from pytorch_lightning import Trainer

exp = Experiment(save_dir='/some/path')
trainer = Trainer(experiment=exp)
...

Then point TensorBoard at that directory:

tensorboard --logdir /some/path     

Lightning automates all of the following (each is also configurable):

  • Checkpointing
  • Computing cluster (SLURM)
  • Debugging
  • Distributed training
  • Experiment logging
  • Training loop
  • Validation loop
  • Testing loop

Demo

# install lightning
pip install pytorch_lightning

# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning
cd examples/new_project_templates/

# all of the following demos use the SAME model to show no modification needs to be made to your code

# train on CPU
python single_cpu_template.py

# train on multiple GPUs on a single node
python single_gpu_node_template.py --gpus "0,1"

# train on 32 GPUs on a cluster (run on a SLURM-managed cluster)
python multi_node_cluster_template.py --nb_gpu_nodes 4 --gpus '0,1,2,3,4,5,6,7'
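
The demo scripts take the GPU list as a comma-separated string. Assuming the templates turn it into integer device indices (check each script's own argument parsing for the exact behavior), a hypothetical helper doing that conversion could look like this:

```python
def parse_gpu_ids(gpus):
    """Turn a comma-separated string such as "0,1" into device indices."""
    # skip empty pieces so a trailing comma does not raise ValueError
    return [int(piece) for piece in gpus.split(',') if piece.strip()]


print(parse_gpu_ids("0,1"))                         # [0, 1]
print(len(parse_gpu_ids('0,1,2,3,4,5,6,7')) * 4)    # 32 GPUs across 4 nodes
```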

Tutorials


Asking for help

Welcome to the Lightning community!

If you have any questions, feel free to:

  1. Read the docs.
  2. Search through the issues.
  3. Ask on Stack Overflow with the tag pytorch-lightning.

If no one replies quickly enough, feel free to post the Stack Overflow link in our Gitter chat!

To chat with the rest of us, visit our Gitter channel!


FAQ

How do I use Lightning for rapid research?
Here's a walk-through

Why was Lightning created?
Lightning has 3 goals in mind:

  1. Maximal flexibility while abstracting out the common boilerplate across research projects.
  2. Reproducibility. If all projects use the LightningModule template, it will be much easier to understand what's going on and where to look! It will also mean every implementation follows a standard format.
  3. Democratizing PyTorch power-user features. Distributed training? 16-bit? Know you need them but don't want to take the time to implement them? All good... these come built into Lightning.

How does Lightning compare with Ignite and fast.ai?
Here's a thorough comparison.

Is this another library I have to learn?
Nope! We use pure PyTorch everywhere and don't add unnecessary abstractions!

Are there plans to support Python 2?
Nope.

Are there plans to support virtualenv?
Nope. Please use Anaconda or Miniconda.

Which PyTorch versions do you support?

  • PyTorch 1.1.0
    # install pytorch 1.1.0 using the official instructions   
    
    # install test-tube 0.6.7.6 which supports 1.1.0   
    pip install test-tube==0.6.7.6   
    
    # install latest Lightning version without upgrading deps    
    pip install -U --no-deps pytorch-lightning
    
  • PyTorch 1.2.0: install via pip as normal.

Custom installation

Bleeding edge

If you can't wait for the next release, install the most up to date code with:

  • using Git (clones the whole repo with its full history)
    pip install git+https://github.com/williamFalcon/pytorch-lightning.git@master --upgrade
    
  • using an instant zip (the latest state of the repo without git history)
    pip install https://github.com/williamFalcon/pytorch-lightning/archive/master.zip --upgrade
    

Any release installation

You can also install any past release from this repository:

pip install https://github.com/williamFalcon/pytorch-lightning/archive/0.4.4.zip --upgrade