
PyTorch Lightning

The Keras for ML researchers using PyTorch. More control. Less boilerplate.

```bash
pip install pytorch-lightning
```

Docs

View the docs here

What is it?

Keras and fast.ai are too abstract for researchers. Lightning abstracts away the full training loop but gives you control at the critical points.

Why would I want to use Lightning?

Because you want best practices such as GPU training, multi-node training, checkpointing, and mixed-precision training for free, while keeping granular control over the core of the training, validation, and testing loops.

To use Lightning, do two things:

  1. Define a Trainer.
  2. Define a LightningModel.

What does Lightning control for me?

Everything, except the following three things:

Automatic training loop

```python
# define what happens for training here
def training_step(self, data_batch, batch_nb):
    x, y = data_batch

    # define your own forward and loss calculation;
    # my_loss stands for any loss, e.g. torch.nn.functional.cross_entropy
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'loss': loss}
```

Automatic validation loop

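```python
# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch

    # define your own forward and loss calculation
    out = self.forward(x)
    loss = my_loss(out, y)

    # accuracy, assuming `out` holds class scores of shape (batch, n_classes)
    val_acc = (out.argmax(dim=1) == y).float().mean()

    # these keys must match the ones read in validation_end
    return {'val_loss': loss, 'val_acc': val_acc}
```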

Aggregate the outputs of validation_step

```python
def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of the dicts returned by each validation_step call
    :return: dict of averaged metrics for the tqdm progress bar
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']

    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
    return tqdm_dic
```
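To make the division of labour concrete, here is a minimal, framework-free sketch of how an automated loop might drive the three hooks. `MiniTrainer` and `ToyModel` are hypothetical names invented for illustration; this is not Lightning's actual Trainer, and it uses plain floats instead of tensors, so no backward pass or optimizer step happens.

```python
# Hypothetical names throughout: MiniTrainer and ToyModel only illustrate the
# control flow of an automated loop; they are not part of Lightning.


class ToyModel:
    """Always predicts the constant `c`; the loss is mean squared error."""

    def __init__(self, c):
        self.c = c

    def forward(self, x):
        return [self.c for _ in x]

    def _mse(self, data_batch):
        x, y = data_batch
        out = self.forward(x)
        return sum((o - t) ** 2 for o, t in zip(out, y)) / len(y)

    def training_step(self, data_batch, batch_nb):
        return {'loss': self._mse(data_batch)}

    def validation_step(self, data_batch, batch_nb):
        return {'val_loss': self._mse(data_batch)}

    def validation_end(self, outputs):
        # average the per-batch validation losses
        return {'val_loss': sum(o['val_loss'] for o in outputs) / len(outputs)}


class MiniTrainer:
    """Owns the loops; the model only supplies the three hooks."""

    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs

    def fit(self, model, train_batches, val_batches):
        history = []
        for epoch in range(self.max_epochs):
            # training loop: one training_step call per batch
            for batch_nb, batch in enumerate(train_batches):
                output = model.training_step(batch, batch_nb)
                # a real trainer would now call output['loss'].backward()
                # and step the optimizer; plain floats have no gradients
            # validation loop: collect every validation_step output ...
            val_outputs = [model.validation_step(batch, batch_nb)
                           for batch_nb, batch in enumerate(val_batches)]
            # ... then hand them all to validation_end for aggregation
            history.append(model.validation_end(val_outputs))
        return history


model = ToyModel(c=1.0)
history = MiniTrainer(max_epochs=1).fit(
    model,
    train_batches=[([0, 0], [1, 1])],
    val_batches=[([0, 0], [1, 3]), ([0], [2])],
)
print(history)  # [{'val_loss': 1.5}]
```

The point of the split is that the model never loops: it only answers "what happens on one batch" and "how to combine the results", while the trainer decides when and how often to ask.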

Lightning gives you options to control the following:

Checkpointing

  • Model saving
  • Model loading
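A common pattern behind safe model saving is to write the checkpoint to a temporary file and then atomically rename it over the old one, so a job killed mid-write never leaves a half-written checkpoint. The sketch below assumes a plain picklable dict as the checkpoint and uses `pickle` to stay dependency-free; a PyTorch project would typically serialize `state_dict()`s with `torch.save` instead. `save_checkpoint` and `load_checkpoint` are hypothetical helpers, not Lightning's API.

```python
import os
import pickle
import tempfile


def save_checkpoint(state, path):
    """Atomically write `state` (any picklable object) to `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file in the same directory so os.replace stays on one
    # filesystem, where the rename is atomic.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix='.tmp')
    try:
        with os.fdopen(fd, 'wb') as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes reach the disk
        # Readers now see either the old checkpoint or the new one, never a partial file.
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise


def load_checkpoint(path):
    """Load a checkpoint written by save_checkpoint (only load trusted files)."""
    with open(path, 'rb') as f:
        return pickle.load(f)
```

For example, `save_checkpoint({'epoch': 3, 'best_val_loss': 0.42}, 'ckpt.pkl')` followed by `load_checkpoint('ckpt.pkl')` returns the same dict.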

Computing cluster (SLURM)

  • Automatic checkpointing
  • Automatic saving and loading
  • Running grid search on a cluster
  • Walltime auto-resubmit
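Walltime auto-resubmit on SLURM usually relies on asking the scheduler for a warning signal shortly before the time limit (for example `#SBATCH --signal=USR1@90`, roughly 90 seconds early; exact delivery depends on how the job is launched), then checkpointing and requeueing the job with `scontrol requeue`. The sketch below shows that general pattern under those assumptions; it is not Lightning's implementation, and `requeue_command`, `install_resubmit_handler` and the `save_fn` callback are hypothetical names.

```python
import os
import signal
import subprocess


def requeue_command(environ=os.environ):
    """Build the `scontrol requeue` call for the current SLURM job, if any."""
    job_id = environ.get('SLURM_JOB_ID')  # set by SLURM inside a job
    if job_id is None:
        return None  # not running under SLURM: nothing to requeue
    return ['scontrol', 'requeue', job_id]


def install_resubmit_handler(save_fn):
    """On SIGUSR1: save a checkpoint via `save_fn`, then requeue the job.

    POSIX only (signal.SIGUSR1 does not exist on Windows); must be called
    from the main thread.
    """
    def handler(signum, frame):
        save_fn()  # persist state first, while there is still time
        cmd = requeue_command()
        if cmd is not None:
            # The requeued job starts from the top, so the training script
            # is expected to load the latest checkpoint on startup.
            subprocess.run(cmd, check=True)

    signal.signal(signal.SIGUSR1, handler)
    return handler
```

This assumes the job is allowed to be requeued on your cluster; whether that is the default depends on the site's SLURM configuration.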

Debugging

Distributed training

Experiment logging

Training loop

Validation loop

Demo

```bash
# install lightning
pip install pytorch-lightning

# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning/examples/new_project_templates/

# run demo (on CPU)
python trainer_gpu_cluster_template.py
```

Without changing the model AT ALL, you can run it on a single GPU, on multiple GPUs, or across multiple nodes.

```bash
# run a grid search on two GPUs
python fully_featured_trainer.py --gpus "0;1"

# run a single model on multiple GPUs
python fully_featured_trainer.py --gpus "0;1" --interactive
```
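One way a grid search over two GPUs can be scheduled is to expand the hyperparameter grid into individual trials and hand them out to the listed devices in turn. The sketch below assumes `--gpus "0;1"` means a semicolon-separated list of device indices, as the commands suggest; `parse_gpus`, `grid_trials` and `assign_round_robin` are hypothetical helpers, not the example script's actual code.

```python
import itertools


def parse_gpus(spec):
    """'0;1' -> [0, 1]. Assumes semicolon-separated device indices."""
    return [int(part) for part in spec.split(';') if part.strip()]


def grid_trials(grid):
    """Expand {'lr': [a, b], 'batch_size': [c]} into one dict per combination."""
    keys = sorted(grid)  # fixed key order makes the trial order reproducible
    return [dict(zip(keys, values))
            for values in itertools.product(*(grid[k] for k in keys))]


def assign_round_robin(trials, gpus):
    """Pair each trial with a GPU index, cycling through the available GPUs."""
    return [(gpus[i % len(gpus)], trial) for i, trial in enumerate(trials)]


gpus = parse_gpus("0;1")
trials = grid_trials({'learning_rate': [0.1, 0.01], 'batch_size': [32, 64]})
for gpu, trial in assign_round_robin(trials, gpus):
    print(gpu, trial)
```

In practice each trial would then run in its own process restricted to its assigned device (for example via `CUDA_VISIBLE_DEVICES`); round-robin is just the simplest assignment and ignores that some trials may run longer than others.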