PyTorch Lightning
The Keras for ML researchers using PyTorch. More control. Less boilerplate.
pip install pytorch-lightning
What is it?
Keras and fast.ai are too abstract for researchers. Lightning abstracts the full training loop but gives you control at the critical points.
Why do I want to use lightning?
Because you don't want to define a training loop, validation loop, gradient clipping, checkpointing, loading, GPU training, etc. every time you start a project. Let lightning handle all of that for you! Just define your data and what happens in the training, testing, and validation loops, and lightning will do the rest.
To use lightning do 2 things:
What does lightning control for me?
Everything! Except the following three things:
What happens in the training loop
# define what happens for training here
def training_step(self, data_batch, batch_nb):
    x, y = data_batch
    # define your own forward and loss calculation
    out = self.forward(x)
    loss = my_loss(out, y)  # my_loss: your own loss function
    return {'loss': loss}
What happens in the validation loop
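# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch
    # define your own forward and loss calculation
    out = self.forward(x)
    loss = my_loss(out, y)
    # accuracy for a classifier: fraction of argmax predictions matching the labels
    acc = (out.argmax(dim=1) == y).float().mean()
    # return the keys that validation_end reads
    return {'val_loss': loss, 'val_acc': acc}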
And what to do with the output of all validation batches
def validation_end(self, outputs):
    """
    Called at the end of validation to aggregate outputs
    :param outputs: list of individual outputs of each validation step
    :return: dict of averaged metrics to display in the progress bar
    """
    val_loss_mean = 0
    val_acc_mean = 0
    for output in outputs:
        val_loss_mean += output['val_loss']
        val_acc_mean += output['val_acc']
    val_loss_mean /= len(outputs)
    val_acc_mean /= len(outputs)
    tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
    return tqdm_dic
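To make the division of labor concrete, here is a minimal pure-Python sketch of the kind of loop that calls the three hooks above. `MiniTrainer` and `ToyModel` are hypothetical stand-ins, not Lightning classes, and losses are plain floats instead of tensors; the real trainer also handles backprop, optimizers, devices, checkpointing, and logging, none of which is shown.

```python
class ToyModel:
    """Hypothetical model that always predicts 0.0; losses are plain Python floats."""

    def forward(self, x):
        return [0.0 for _ in x]

    def _mse(self, out, y):
        return sum((o - t) ** 2 for o, t in zip(out, y)) / len(y)

    def training_step(self, data_batch, batch_nb):
        x, y = data_batch
        return {'loss': self._mse(self.forward(x), y)}

    def validation_step(self, data_batch, batch_nb):
        x, y = data_batch
        out = self.forward(x)
        acc = sum(o == t for o, t in zip(out, y)) / len(y)
        return {'val_loss': self._mse(out, y), 'val_acc': acc}

    def validation_end(self, outputs):
        n = len(outputs)
        return {'val_loss': sum(o['val_loss'] for o in outputs) / n,
                'val_acc': sum(o['val_acc'] for o in outputs) / n}


class MiniTrainer:
    """Hypothetical, heavily simplified stand-in for the loop lightning runs for you."""

    def __init__(self, max_epochs=1):
        self.max_epochs = max_epochs

    def fit(self, model, train_batches, val_batches):
        history = []
        for epoch in range(self.max_epochs):
            # training loop: the user only says what one step does
            for batch_nb, batch in enumerate(train_batches):
                step_out = model.training_step(batch, batch_nb)
                # a real trainer would now backprop step_out['loss'] and step the optimizer
            # validation loop: collect the output of every step...
            outputs = [model.validation_step(batch, batch_nb)
                       for batch_nb, batch in enumerate(val_batches)]
            # ...then pass them all to the user's aggregation hook
            history.append(model.validation_end(outputs))
        return history


batches = [([1.0, 2.0], [0.0, 2.0]), ([3.0], [1.0])]
history = MiniTrainer(max_epochs=2).fit(ToyModel(), batches, batches)
print(history[-1])  # {'val_loss': 1.5, 'val_acc': 0.25}
```

The point of the design is that `fit` never needs to know what the model computes; it only relies on the hook names and the dictionary keys they return.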
Lightning gives you options to control the following:
Checkpointing
- Model saving
- Model loading
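Conceptually, a checkpoint only needs enough state to resume training: the epoch, the model weights, and the optimizer state. The sketch below illustrates that idea with plain JSON. `save_checkpoint` and `load_checkpoint` are hypothetical helpers, not lightning functions; a PyTorch model would normally serialize its `state_dict()` with `torch.save` instead.

```python
import json
import os
import tempfile


def save_checkpoint(path, epoch, weights, optimizer_state):
    """Hypothetical helper: write everything needed to resume training."""
    checkpoint = {'epoch': epoch, 'weights': weights, 'optimizer_state': optimizer_state}
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(checkpoint, f)
    # write to a temp file, then rename, so an interrupted write never clobbers
    # the previous checkpoint
    os.replace(tmp_path, path)


def load_checkpoint(path):
    """Hypothetical helper: read a checkpoint and return the epoch to resume from."""
    with open(path) as f:
        checkpoint = json.load(f)
    return checkpoint['epoch'] + 1, checkpoint['weights'], checkpoint['optimizer_state']


ckpt_path = os.path.join(tempfile.mkdtemp(), 'last.ckpt')
save_checkpoint(ckpt_path, epoch=3, weights={'w': [0.5, -1.0]}, optimizer_state={'lr': 0.01})
start_epoch, weights, opt_state = load_checkpoint(ckpt_path)
print(start_epoch, weights, opt_state)  # 4 {'w': [0.5, -1.0]} {'lr': 0.01}
```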
Computing cluster (SLURM)
- Automatic checkpointing
- Automatic saving, loading
- Running grid search on a cluster
- Walltime auto-resubmit
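Two simple ideas sit underneath the cluster features: expanding a hyperparameter grid into one trial per combination, and checkpointing then resubmitting a job before the scheduler's walltime limit kills it. A minimal sketch of both, assuming hypothetical helpers `build_grid` and `should_resubmit` and an arbitrary 5-minute safety margin; the actual SLURM submission is left out.

```python
import itertools


def build_grid(search_space):
    """Expand {'lr': [...], 'layers': [...]} into one dict per trial."""
    keys = sorted(search_space)
    return [dict(zip(keys, values))
            for values in itertools.product(*(search_space[k] for k in keys))]


def should_resubmit(elapsed_s, walltime_s, margin_s=300):
    """Hypothetical policy: checkpoint and requeue once less than margin_s seconds remain."""
    return walltime_s - elapsed_s <= margin_s


trials = build_grid({'lr': [0.1, 0.01], 'layers': [2, 4]})
print(len(trials), trials[0])                          # 4 {'layers': 2, 'lr': 0.1}
print(should_resubmit(elapsed_s=3500, walltime_s=3600))  # True
```

On a cluster, each dict in `trials` would typically become its own job, and a job that hits the resubmit condition would save a checkpoint so the requeued copy can resume from it.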
Debugging
- Fast dev run
- Inspect gradient norms
- Log GPU usage
- Make model overfit on subset of data
- Print the parameter count by layer
- Print which gradients are NaN
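Several of the debugging options reduce to small computations over per-layer gradients and weight shapes. The sketch below shows that logic on plain lists of floats; in PyTorch the same idea would be applied to each parameter's `.grad` tensor. The function names and the layer names `fc1`/`fc2` are illustrative only.

```python
import math


def grad_norms(grads):
    """L2 norm of each layer's gradient; `grads` maps layer name -> list of floats."""
    return {name: math.sqrt(sum(g * g for g in values)) for name, values in grads.items()}


def nan_grads(grads):
    """Names of layers whose gradient contains at least one NaN."""
    return [name for name, values in grads.items() if any(math.isnan(g) for g in values)]


def param_count_by_layer(shapes):
    """Number of parameters per layer, given each layer's weight shape."""
    return {name: math.prod(shape) for name, shape in shapes.items()}


grads = {'fc1': [3.0, 4.0], 'fc2': [float('nan'), 1.0]}
print(grad_norms(grads)['fc1'])   # 5.0
print(nan_grads(grads))           # ['fc2']
print(param_count_by_layer({'fc1': (784, 128), 'fc2': (128, 10)}))
# {'fc1': 100352, 'fc2': 1280}
```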
Distributed training
Experiment Logging
- Display metrics in progress bar
- Log arbitrary metrics
- Log metric row every k batches
- Process position
- Save a snapshot of all hyperparameters
- Snapshot code for a training run
- Write logs to a CSV file every k batches
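Logging "every k batches" usually means buffering rows in memory and writing them out in chunks, which keeps file I/O out of every training step. A minimal sketch, assuming a hypothetical `CsvMetricLogger` class (not lightning's logger) that writes to any text stream:

```python
import csv
import io


class CsvMetricLogger:
    """Hypothetical logger: buffer metric rows and write them to CSV every k batches."""

    def __init__(self, stream, fieldnames, flush_every_k=2):
        self.writer = csv.DictWriter(stream, fieldnames=fieldnames)
        self.writer.writeheader()
        self.flush_every_k = flush_every_k
        self.buffer = []

    def log(self, batch_nb, metrics):
        self.buffer.append({'batch_nb': batch_nb, **metrics})
        # only write every k batches to keep I/O off the hot path
        if (batch_nb + 1) % self.flush_every_k == 0:
            self.flush()

    def flush(self):
        self.writer.writerows(self.buffer)
        self.buffer.clear()


stream = io.StringIO()
logger = CsvMetricLogger(stream, fieldnames=['batch_nb', 'loss'], flush_every_k=2)
for batch_nb, loss in enumerate([0.9, 0.7, 0.6]):
    logger.log(batch_nb, {'loss': loss})
logger.flush()  # write whatever is left at the end of the epoch
print(stream.getvalue())
```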
Training loop
- Accumulate gradients
- Anneal learning rate
- Force training for min or max epochs
- Force disable early stop
- Use multiple optimizers (like GANs)
- Set how much of the training set to check (1-100%)
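Accumulating gradients over k batches means the optimizer steps once every k batches, so the effective batch size becomes batch_size × k. Checking only a fraction of the training set simply shortens each epoch. A minimal sketch of both schedules; the function names are illustrative and not lightning's API.

```python
def batches_to_run(num_batches, percent_check):
    """Training batches per epoch when only a fraction (0 < p <= 1) of the set is used."""
    return max(1, int(num_batches * percent_check))


def optimizer_step_batches(num_batches, accumulate_grad_batches):
    """Batch indices after which the optimizer steps when gradients are accumulated."""
    return [i for i in range(num_batches)
            if (i + 1) % accumulate_grad_batches == 0]


batch_size, k = 32, 4
print(batches_to_run(1000, 0.1))      # 100
print(optimizer_step_batches(10, k))  # [3, 7]
print(batch_size * k)                 # effective batch size: 128
```

In this simplified sketch the last two of the ten batches never trigger a step; a real loop has to decide what to do with such leftover accumulated gradients.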
Validation loop
- Check validation every n epochs
- Set how much of the validation set to check
- Set how much of the test set to check
- Set validation check frequency within 1 training epoch
- Set the number of validation sanity steps
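The validation options decide when validation runs: every n epochs, several times within one epoch (expressed as a fraction of the epoch), and briefly before training starts as a sanity check. A minimal sketch under those assumptions; the function names are hypothetical.

```python
def val_check_batches(num_train_batches, val_check_interval):
    """Training batch indices after which validation runs, for a fractional interval."""
    step = max(1, int(num_train_batches * val_check_interval))
    return [i for i in range(num_train_batches) if (i + 1) % step == 0]


def should_validate(epoch, check_every_n_epochs):
    """Whether validation runs at the end of this (0-indexed) epoch."""
    return (epoch + 1) % check_every_n_epochs == 0


def sanity_check_batches(val_batches, num_sanity_steps):
    """A few validation batches run once before training, to surface bugs early."""
    return val_batches[:num_sanity_steps]


print(val_check_batches(100, 0.25))                     # [24, 49, 74, 99]
print([e for e in range(6) if should_validate(e, 2)])   # [1, 3, 5]
print(sanity_check_batches(list(range(10)), 2))         # [0, 1]
```

Running a couple of sanity batches first means a typo in `validation_step` or `validation_end` fails in seconds instead of after a full training epoch.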
Demo
# install lightning
pip install pytorch-lightning
# clone lightning for the demo
git clone https://github.com/williamFalcon/pytorch-lightning.git
cd pytorch-lightning/examples/new_project_templates/
# run demo (on cpu)
python trainer_gpu_cluster_template.py
Without changing the model AT ALL, you can run the model on a single GPU, over multiple GPUs, or over multiple nodes.
# run a grid search on two gpus
python fully_featured_trainer.py --gpus "0;1"
# run single model on multiple gpus
python fully_featured_trainer.py --gpus "0;1" --interactive
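The `--gpus "0;1"` argument above lists GPU ids separated by semicolons. How the demo script distributes work internally is not spelled out here; as a hypothetical sketch, a grid search could parse that string and assign one trial per GPU in round-robin order, whereas `--interactive` mode hands all listed GPUs to a single model. `parse_gpus` and `assign_trials` are illustrative names, not functions from the repository.

```python
def parse_gpus(gpus_arg):
    """Turn a semicolon-separated string such as "0;1" into a list of GPU ids."""
    return [int(g) for g in gpus_arg.split(';') if g.strip()]


def assign_trials(trials, gpu_ids):
    """Round-robin assignment: trial i goes to gpu_ids[i % len(gpu_ids)]."""
    return [(trial, gpu_ids[i % len(gpu_ids)]) for i, trial in enumerate(trials)]


gpu_ids = parse_gpus("0;1")
print(gpu_ids)  # [0, 1]
print(assign_trials(['lr=0.1', 'lr=0.01', 'lr=0.001'], gpu_ids))
# [('lr=0.1', 0), ('lr=0.01', 1), ('lr=0.001', 0)]
```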