Build and train PyTorch models and connect them to the ML lifecycle using Lightning App templates, without handling DIY infrastructure, cost management, scaling, and other headaches.
William Falcon 92f5e026bb added docs page 2019-06-26 19:47:31 -04:00

Pytorch Lightning

The Keras for ML researchers using PyTorch. More control. Less boilerplate.


pip install pytorch-lightning    

Docs

View the docs here

Disclaimer

This is a research tool I built for myself internally while doing my PhD. The API is not 100% production quality, but my hope is that by open-sourcing, we can all get it there (I don't have too much time nowadays to write production-level code).

What is it?

Keras is too abstract for researchers. Lightning abstracts the full training loop but gives you control at the critical points.

Why do I want to use lightning?

Because you want to use best practices and get GPU training, multi-node training, checkpointing, mixed precision, etc. for free.

To use lightning do 2 things:

  1. Define a model with the lightning interface.
  2. Feed this model to the lightning trainer.
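The split above (the model supplies the steps, the trainer owns the loop) can be illustrated without any framework. The sketch below is a toy, framework-free illustration of that contract only; `MeanModel` and `ToyTrainer` are hypothetical names, and the real Lightning `Trainer` does far more (GPUs, checkpointing, logging).

```python
class MeanModel:
    """Hypothetical model: learns one scalar w that approximates the mean of the data."""

    def __init__(self, lr=0.1):
        self.w = 0.0
        self.lr = lr

    def training_step(self, data_batch, batch_nb):
        # squared-error loss and its gradient with respect to w
        loss = sum((self.w - x) ** 2 for x in data_batch) / len(data_batch)
        grad = sum(2 * (self.w - x) for x in data_batch) / len(data_batch)
        self.w -= self.lr * grad  # plain gradient-descent update
        return loss, {'train_loss': loss}

    def validation_step(self, data_batch, batch_nb):
        loss = sum((self.w - x) ** 2 for x in data_batch) / len(data_batch)
        return loss, {'val_loss': loss}

    def validation_end(self, outputs):
        # collate the per-batch metric dicts into one number
        return {'avg_val_loss': sum(o['val_loss'] for o in outputs) / len(outputs)}


class ToyTrainer:
    """Hypothetical trainer: owns the loop so the model does not have to."""

    def __init__(self, max_epochs=50):
        self.max_epochs = max_epochs

    def fit(self, model, train_batches, val_batches):
        metrics = {}
        for epoch in range(self.max_epochs):
            for batch_nb, batch in enumerate(train_batches):
                model.training_step(batch, batch_nb)
            # keep only the metric dict from each validation step
            outputs = [model.validation_step(batch, i)[1] for i, batch in enumerate(val_batches)]
            metrics = model.validation_end(outputs)
        return metrics


model = MeanModel()
metrics = ToyTrainer().fit(model, train_batches=[[1.0, 2.0], [3.0, 4.0]], val_batches=[[2.0, 3.0]])
print(round(model.w, 3), metrics)
```

With a constant step size and alternating batches, `w` settles close to, though not exactly at, the overall mean of 2.5.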

Example usage

from pytorch_lightning import Trainer
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint

# MyModel is your own model class, defined with the lightning interface
model = MyModel()

trainer = Trainer(
    checkpoint_callback=ModelCheckpoint(...),
    early_stop_callback=EarlyStopping(...),
    gpus=[0,1]
)

trainer.fit(model)

What are some key lightning features?

  • Automatic training loop
# define what happens for training here
def training_step(self, data_batch, batch_nb):
    ...
  • Automatic validation loop
# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    ...
  • Automatic early stopping
callback = EarlyStopping(...)
Trainer(early_stop_callback=callback)
  • Learning rate annealing
# anneal at 100 and 200 epochs
Trainer(lr_scheduler_milestones=[100, 200])
  • 16-bit precision training
Trainer(use_amp=True, amp_level='O2')
  • Multi-GPU training
# train on 4 GPUs
Trainer(gpus=[0, 1, 2, 3])
  • Automatic checkpointing
# do 3 things:
# 1 give the trainer a checkpoint callback
Trainer(checkpoint_callback=ModelCheckpoint(...))

# 2 return what to save in a checkpoint
def get_save_dict(self):
    return {'state_dict': self.state_dict()}


# 3 use the checkpoint to reset your model state
def load_model_specific(self, checkpoint):
    self.load_state_dict(checkpoint['state_dict'])
  1. Learning rate annealing.
  2. Can train complex models like GANs or anything with multiple optimizers.
  3. Weight checkpointing.
  4. Model saving.
  5. Model loading.
  6. Log training details (through test-tube).
  7. Run training on multiple GPUs (through test-tube).
  8. Run training on a GPU cluster managed by SLURM (through test-tube).
  9. Distribute memory-bound models on multiple GPUs.
  10. Give your model hyperparameters parsed from the command line OR a JSON file.
  11. Run your model in a dev environment where nothing logs.
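The `lr_scheduler_milestones=[100, 200]` option describes a step schedule: the learning rate drops each time training passes a milestone epoch. A minimal sketch of the resulting rate per epoch, assuming each milestone multiplies the rate by `gamma=0.1` (the default of PyTorch's `torch.optim.lr_scheduler.MultiStepLR`; the decay factor Lightning actually applies is not stated in this README). `multistep_lr` is a hypothetical helper.

```python
from bisect import bisect_right


def multistep_lr(base_lr, epoch, milestones, gamma=0.1):
    """Learning rate at `epoch` after decaying by `gamma` at each milestone reached.

    Assumption: gamma=0.1, matching torch.optim.lr_scheduler.MultiStepLR's default.
    """
    # bisect_right counts how many milestones are <= epoch
    n_decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** n_decays


for epoch in (0, 99, 100, 199, 200, 250):
    print(epoch, multistep_lr(0.1, epoch, [100, 200]))
```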

Usage

To use lightning do 2 things:

  1. Define a trainer (which will run ALL your models).
  2. Define a model.

Quick demo

Run the following demo to see how it works:

# install lightning
pip install pytorch-lightning

# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning/docs/source/examples

# run demo (on cpu)
python fully_featured_trainer.py

Without changing the model AT ALL, you can run the model on a single GPU, over multiple GPUs, or over multiple nodes.

# run a grid search on two gpus
python fully_featured_trainer.py --gpus "0;1"

# run single model on multiple gpus
python fully_featured_trainer.py --gpus "0;1" --interactive
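The search demo draws hyperparameter combinations through test-tube's `HyperOptArgumentParser(strategy='random_search')`. The idea behind random search is sketched below in plain Python; the search space and the `sample_trials` helper are hypothetical and are not test-tube's implementation.

```python
import random

# Hypothetical search space: each hyperparameter maps to its candidate values.
SEARCH_SPACE = {
    'learning_rate': [0.1, 0.01, 0.001],
    'batch_size': [32, 64, 128],
    'drop_prob': [0.0, 0.2, 0.5],
}


def sample_trials(space, n_trials, seed=0):
    """Draw up to n_trials distinct random configurations from the grid defined by `space`."""
    rng = random.Random(seed)  # seeded so the search is reproducible
    total = 1
    for values in space.values():
        total *= len(values)
    n_trials = min(n_trials, total)  # cannot draw more distinct configs than exist

    seen, trials = set(), []
    while len(trials) < n_trials:
        config = tuple((name, rng.choice(values)) for name, values in space.items())
        if config not in seen:
            seen.add(config)
            trials.append(dict(config))
    return trials


for trial in sample_trials(SEARCH_SPACE, n_trials=4):
    print(trial)
```

Each sampled configuration is independent, so trials can be spread across the available GPUs.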

Basic trainer example

See this demo for a more robust trainer example.

import os
import sys

from test_tube import HyperOptArgumentParser, Experiment
from pytorch_lightning.models.trainer import Trainer
from pytorch_lightning.utils.arg_parse import add_default_args
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
from demo.example_model import ExampleModel


def main(hparams):
    """
    Main training routine specific for this project
    :param hparams:
    :return:
    """
    # init experiment
    exp = Experiment(
        name=hparams.tt_name,
        debug=hparams.debug,
        save_dir=hparams.tt_save_path,
        version=hparams.hpc_exp_number,
        autosave=False,
        description=hparams.tt_description
    )

    exp.argparse(hparams)
    exp.save()

    model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)

    # build model
    model = ExampleModel(hparams)

    # callbacks
    # val_acc should increase, so both callbacks watch it in 'max' mode
    early_stop = EarlyStopping(monitor='val_acc', patience=3, mode='max', verbose=True)
    checkpoint = ModelCheckpoint(filepath=model_save_path, save_function=None, save_best_only=True, verbose=True, monitor='val_acc', mode='max')

    # configure trainer
    trainer = Trainer(experiment=exp, checkpoint_callback=checkpoint, early_stop_callback=early_stop)

    # train model
    trainer.fit(model)


if __name__ == '__main__':

    # use default args given by lightning
    root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0]
    parent_parser = HyperOptArgumentParser(strategy='random_search', add_help=False)
    add_default_args(parent_parser, root_dir)

    # allow model to overwrite or extend args
    parser = ExampleModel.add_model_specific_args(parent_parser)
    hyperparams = parser.parse_args()

    # train model
    main(hyperparams)
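The `EarlyStopping` callback in the trainer example is configured with a monitored metric, a `patience`, and a `mode`. A minimal sketch of the rule those settings describe, written independently of Lightning (`EarlyStopper` is a hypothetical class): training stops once the metric has gone `patience` consecutive checks without improving, where "improving" means larger in `'max'` mode and smaller in `'min'` mode.

```python
class EarlyStopper:
    """Hypothetical stand-in for an early-stopping callback.

    mode='max' treats larger values as better (e.g. accuracy);
    mode='min' treats smaller values as better (e.g. loss).
    """

    def __init__(self, patience=3, mode='max'):
        if mode not in ('min', 'max'):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.best = None
        self.wait = 0  # checks since the last improvement

    def _improved(self, value):
        if self.best is None:
            return True
        return value > self.best if self.mode == 'max' else value < self.best

    def should_stop(self, value):
        """Record one validation result; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStopper(patience=3, mode='max')
val_accs = [0.60, 0.70, 0.72, 0.71, 0.72, 0.70, 0.75]
for epoch, acc in enumerate(val_accs):
    if stopper.should_stop(acc):
        print(f'stopping at epoch {epoch}, best val_acc={stopper.best}')
        break
```

The mode has to match the metric's direction: monitoring accuracy in `'min'` mode would count every drop in accuracy as an improvement.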

Basic model example

Here we only show the method signatures. It's up to you to define the content.

import numpy as np
from torch import nn
from test_tube import HyperOptArgumentParser

# RootModule is lightning's base model class; import it from your installed pytorch_lightning version.
# some_loss and pytorch_dataloader are placeholders for your own loss function and DataLoader builder.


class My_Model(RootModule):
    def __init__(self, hparams):
        super(My_Model, self).__init__(hparams)
        # define model
        self.l1 = nn.Linear(200, 10)

    # ---------------
    # TRAINING
    def training_step(self, data_batch, batch_nb):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)

        return loss, {'train_loss': loss}

    def validation_step(self, data_batch, batch_nb):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)
        val_acc = (y_hat.argmax(dim=1) == y).float().mean()

        return loss, {'val_loss': loss, 'val_acc': val_acc}

    def validation_end(self, outputs):
        total_accs = []

        for output in outputs:
            total_accs.append(output['val_acc'].item())

        # return a dict
        return {'total_acc': np.mean(total_accs)}

    # ---------------
    # SAVING
    def get_save_dict(self):
        # lightning saves for you. Here's your chance to say what you want to save
        checkpoint = {'state_dict': self.state_dict()}

        return checkpoint

    def load_model_specific(self, checkpoint):
        # lightning loads for you. Here's your chance to say what you want to load
        self.load_state_dict(checkpoint['state_dict'])

    # ---------------
    # TRAINING CONFIG
    def configure_optimizers(self):
        # give lightning the list of optimizers you want to use.
        # lightning will call this automatically
        optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
        return [optimizer]

    @property
    def tng_dataloader(self):
        return pytorch_dataloader('train')

    @property
    def val_dataloader(self):
        return pytorch_dataloader('val')

    @property
    def test_dataloader(self):
        return pytorch_dataloader('test')

    # ---------------
    # MODIFY YOUR COMMAND LINE ARGS
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--out_features', default=20, type=int)
        return parser
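`add_model_specific_args` mirrors the standard `argparse` parent-parser pattern already visible in the trainer example: a shared parser built with `add_help=False` is passed through `parents=[...]` to a model-specific parser that adds its own flags. The same pattern using only the standard library's `argparse` is sketched below; `build_parent_parser` and the `--gpus`/`--learning_rate` flags are hypothetical stand-ins for Lightning's default args.

```python
import argparse


def build_parent_parser():
    # Shared, model-agnostic options. add_help=False avoids a duplicate -h
    # option when this parser is reused as a parent.
    parent = argparse.ArgumentParser(add_help=False)
    parent.add_argument('--gpus', type=str, default=None)  # hypothetical shared flag
    parent.add_argument('--learning_rate', type=float, default=0.001)
    return parent


def add_model_specific_args(parent_parser):
    # The child parser inherits every option defined on the parent.
    parser = argparse.ArgumentParser(parents=[parent_parser])
    parser.add_argument('--out_features', type=int, default=20)
    return parser


parser = add_model_specific_args(build_parent_parser())
hparams = parser.parse_args(['--learning_rate', '0.01', '--out_features', '32'])
print(hparams.learning_rate, hparams.out_features, hparams.gpus)
```

Because each model only extends the parent, `main` can stay the same while every model contributes its own arguments.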

Details

Model definition

| Name | Description | Input | Return |
|---|---|---|---|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
| get_save_dict | Called when your model needs to be saved (checkpoints, HPC save, etc.) | None | dict to be saved |

Model training

| Name | Description | Input | Return |
|---|---|---|---|
| configure_optimizers | Called during training setup | None | list: optimizers you want to use |
| tng_dataloader | Called during training | None | PyTorch dataloader |
| val_dataloader | Called during validation | None | PyTorch dataloader |
| test_dataloader | Called during testing | None | PyTorch dataloader |
| add_model_specific_args | Called with the args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse |

Model Saving/Loading

| Name | Description | Input | Return |
|---|---|---|---|
| get_save_dict | Called when your model needs to be saved (checkpoints, HPC save, etc.) | None | dict to be saved |
| load_model_specific | Called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
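The two saving hooks form a round trip: whatever dictionary `get_save_dict` returns is what `load_model_specific` later receives. A sketch of that contract with a toy model, using the standard library's `pickle` in place of Lightning's checkpoint I/O. `ToyModel`, `save_checkpoint`, and `load_checkpoint` are hypothetical; a real PyTorch model would store `self.state_dict()` and typically use `torch.save`/`torch.load`.

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical model whose 'state' is a plain dict of weights."""

    def __init__(self):
        self.weights = {'w': 0.0, 'b': 0.0}
        self.epoch = 0

    def get_save_dict(self):
        # Decide what goes into the checkpoint.
        return {'state_dict': dict(self.weights), 'epoch': self.epoch}

    def load_model_specific(self, checkpoint):
        # Restore from the dict produced by get_save_dict.
        self.weights = dict(checkpoint['state_dict'])
        self.epoch = checkpoint['epoch']


def save_checkpoint(model, path):
    with open(path, 'wb') as f:
        pickle.dump(model.get_save_dict(), f)


def load_checkpoint(model, path):
    with open(path, 'rb') as f:
        model.load_model_specific(pickle.load(f))


trained = ToyModel()
trained.weights = {'w': 1.5, 'b': -0.25}
trained.epoch = 7

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'checkpoint.pkl')
    save_checkpoint(trained, path)
    restored = ToyModel()
    load_checkpoint(restored, path)

print(restored.weights, restored.epoch)
```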

Optional model hooks

Add these to the model whenever you want to configure training behavior.

Model lifecycle hooks

Use these hooks to customize functionality

| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| on_batch_start() | Called right before the batch starts | - | - | N |
| on_batch_end() | Called right after the batch ends | - | - | N |
| on_epoch_start() | Called right before the epoch starts | - | - | N |
| on_epoch_end() | Called right after the epoch ends | - | - | N |
| on_pre_performance_check() | Called right before the performance check starts | - | - | N |
| on_post_performance_check() | Called right after the performance check ends | - | - | N |
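Since none of these hooks is required, a trainer only needs to call the ones a model actually defines. The sketch below shows one way such optional hooks could be wired into a loop; the exact ordering (in particular running the performance check once per epoch, just before `on_epoch_end`) is an assumption for illustration, not Lightning's documented call sequence. `call_hook`, `run_epochs`, and `EpochLogger` are hypothetical.

```python
def call_hook(model, name):
    # Hooks are optional: skip any the model does not define.
    hook = getattr(model, name, None)
    if callable(hook):
        hook()


def run_epochs(model, n_epochs, n_batches):
    for _ in range(n_epochs):
        call_hook(model, 'on_epoch_start')
        for _ in range(n_batches):
            call_hook(model, 'on_batch_start')
            # training_step(...) would run here
            call_hook(model, 'on_batch_end')
        call_hook(model, 'on_pre_performance_check')
        # validation_step(...) and validation_end(...) would run here
        call_hook(model, 'on_post_performance_check')
        call_hook(model, 'on_epoch_end')


class EpochLogger:
    """Hypothetical model that implements only two of the optional hooks."""

    def __init__(self):
        self.calls = []

    def on_epoch_start(self):
        self.calls.append('epoch_start')

    def on_epoch_end(self):
        self.calls.append('epoch_end')


logger = EpochLogger()
run_epochs(logger, n_epochs=2, n_batches=3)
print(logger.calls)
```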