Trainer
The Lightning Trainer abstracts best practices for running a training, validation, and test routine. It calls into your model wherever it hands over full control, and otherwise applies training conventions that are now standard practice in AI research.
Basic use of the Trainer:
```python
from pytorch_lightning import Trainer

# LightningTemplate is a placeholder for your own LightningModule subclass.
model = LightningTemplate()
trainer = Trainer()
trainer.fit(model)
```
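To make the division of labour concrete, here is a minimal, framework-free sketch of the kind of loop `trainer.fit` runs. It is not Lightning's implementation: `ToyModel`, `run_fit`, and the one-parameter regression are hypothetical, and only the `training_step`/`validation_step` hook names echo Lightning's. The loop owns iteration and the update rule; the model only supplies what happens inside each step.

```python
from typing import List


class ToyModel:
    """Hypothetical stand-in for a LightningModule: it only defines the steps."""

    def __init__(self) -> None:
        self.weight = 0.0  # single scalar parameter, learning y = 2x

    def training_step(self, batch):
        x, y = batch
        pred = self.weight * x
        grad = 2 * (pred - y) * x  # derivative of the squared error w.r.t. weight
        return {"loss": (pred - y) ** 2, "grad": grad}

    def validation_step(self, batch):
        x, y = batch
        return {"val_loss": (self.weight * x - y) ** 2}


def run_fit(model, train_data, val_data, max_epochs: int = 20,
            lr: float = 0.05) -> List[float]:
    """The loop owns iteration; the model is called only inside each step."""
    history = []
    for _ in range(max_epochs):
        for batch in train_data:
            out = model.training_step(batch)
            model.weight -= lr * out["grad"]  # the "optimizer step" stays in the loop
        val = [model.validation_step(b)["val_loss"] for b in val_data]
        history.append(sum(val) / len(val))  # mean validation loss per epoch
    return history
```

Usage: `run_fit(ToyModel(), [(1.0, 2.0), (2.0, 4.0)], [(4.0, 8.0)])` returns one mean validation loss per epoch.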
But of course the fun is in all the advanced things it can do:
Checkpointing
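A common checkpointing policy is to save whenever a monitored validation metric improves and keep only the best file. The sketch below assumes a hypothetical `BestCheckpoint` helper that writes JSON; a real trainer would serialize model and optimizer state instead, and Lightning's own checkpoint options may differ.

```python
import json
import os
import tempfile


class BestCheckpoint:
    """Hypothetical helper: keep only the checkpoint with the lowest monitored value."""

    def __init__(self, dirpath: str, monitor: str = "val_loss") -> None:
        self.dirpath = dirpath
        self.monitor = monitor
        self.best = float("inf")
        self.best_path = None

    def on_validation_end(self, epoch: int, metrics: dict, state: dict) -> bool:
        value = metrics[self.monitor]
        if value >= self.best:
            return False  # no improvement: leave the existing checkpoint alone
        path = os.path.join(self.dirpath, f"epoch={epoch}.json")
        with open(path, "w") as f:
            json.dump({"epoch": epoch, self.monitor: value, "state": state}, f)
        if self.best_path is not None:
            os.remove(self.best_path)  # drop the previous best file
        self.best, self.best_path = value, path
        return True
```

Usage: create it with `BestCheckpoint(tempfile.mkdtemp())` and call `on_validation_end` after each validation pass.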
Computing cluster (SLURM)
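On a SLURM cluster, each process can discover its place in the job from environment variables that SLURM sets for every task. The sketch reads a few of them (`SLURM_JOB_ID`, `SLURM_PROCID`, `SLURM_LOCALID`, `SLURM_NTASKS`, `SLURM_NODEID`); the helper `slurm_topology` is hypothetical and is not how Lightning itself is wired.

```python
import os
from typing import Mapping, Optional


def slurm_topology(env: Optional[Mapping[str, str]] = None) -> Optional[dict]:
    """Read the process layout SLURM exports to each task; None outside SLURM."""
    env = os.environ if env is None else env
    if "SLURM_JOB_ID" not in env:
        return None  # not running under SLURM
    return {
        "job_id": env["SLURM_JOB_ID"],
        "global_rank": int(env.get("SLURM_PROCID", "0")),  # rank across all tasks
        "local_rank": int(env.get("SLURM_LOCALID", "0")),  # rank within this node
        "world_size": int(env.get("SLURM_NTASKS", "1")),   # total number of tasks
        "node_rank": int(env.get("SLURM_NODEID", "0")),
    }
```

Passing an explicit mapping keeps the function testable off-cluster.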
Debugging
- Fast dev run
- Inspect gradient norms
- Log GPU usage
- Make model overfit on subset of data
- Print the parameter count by layer
- Print which gradients are NaN
- Print the input and output size of every module in the system
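Several of the debugging aids above reduce to simple bookkeeping over gradients. A stdlib-only sketch, assuming gradients are available as a hypothetical `{layer_name: [values]}` mapping rather than real tensors, that reports the parameter count by layer, per-layer and total gradient norms, and which layers contain NaNs:

```python
import math
from typing import Dict, List


def grad_report(grads: Dict[str, List[float]], norm_type: float = 2.0) -> dict:
    """Per-layer gradient norms, the total norm, and the layers holding NaNs.

    `grads` maps a layer name to its flattened gradient values; it stands in
    for iterating over a real model's named parameters.
    """
    norms = {}
    nan_layers = []
    for name, values in grads.items():
        if any(math.isnan(v) for v in values):
            nan_layers.append(name)
            continue  # a NaN would poison the norm, so report it separately
        norms[name] = sum(abs(v) ** norm_type for v in values) ** (1.0 / norm_type)
    # The p-norm of the per-layer p-norms equals the p-norm over all values.
    total = sum(n ** norm_type for n in norms.values()) ** (1.0 / norm_type)
    return {
        "param_count": {name: len(v) for name, v in grads.items()},
        "grad_norm": norms,
        "grad_norm_total": total,
        "nan_layers": nan_layers,
    }
```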
Distributed training
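In data-parallel training each process must see a disjoint, equally sized slice of the data. The sketch below shows the padding-and-striding scheme that PyTorch's `DistributedSampler` uses when shuffling is off; `shard_indices` is a hypothetical name and this is an illustration, not Lightning code.

```python
import math
from typing import List


def shard_indices(dataset_len: int, rank: int, world_size: int) -> List[int]:
    """Split dataset indices across processes so every rank gets the same count."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    per_rank = math.ceil(dataset_len / world_size)
    total = per_rank * world_size
    # Pad by wrapping around the dataset so `total` divides evenly.
    indices = [i % dataset_len for i in range(total)]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]
```

Padding repeats a few samples so that no rank finishes an epoch early and stalls the others.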
Experiment Logging
- Display metrics in progress bar
- Log metric row every k batches
- Process position (stack progress bars when several trainers share a machine)
- TensorBoard support
- Save a snapshot of all hyperparameters
- Snapshot code for a training run
- Write the log file to CSV every k batches
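Two of the logging options above, keeping a metric row every k batches and writing logs to CSV every k batches, can be decoupled with a small buffer. This is a sketch with a hypothetical `CSVMetricLogger`; the interval names are assumptions, not Lightning's arguments.

```python
import csv
import io
from typing import Dict, List, Optional, TextIO


class CSVMetricLogger:
    """Hypothetical logger: keep a metric row every `row_log_interval` batches
    and write the buffered rows as CSV every `save_interval` batches."""

    def __init__(self, fieldnames: List[str], stream: Optional[TextIO] = None,
                 row_log_interval: int = 10, save_interval: int = 100) -> None:
        # Default to an in-memory buffer; pass an open file to write to disk.
        self.stream = stream if stream is not None else io.StringIO()
        self.writer = csv.DictWriter(self.stream, fieldnames=["batch_idx"] + fieldnames)
        self.writer.writeheader()
        self.row_log_interval = row_log_interval
        self.save_interval = save_interval
        self.buffer: List[Dict[str, float]] = []

    def on_batch_end(self, batch_idx: int, metrics: Dict[str, float]) -> None:
        if batch_idx % self.row_log_interval == 0:
            self.buffer.append({"batch_idx": batch_idx, **metrics})
        if batch_idx % self.save_interval == 0:
            self.flush()

    def flush(self) -> None:
        self.writer.writerows(self.buffer)  # write everything buffered so far
        self.buffer.clear()
```

Buffering keeps frequent row logging cheap while limiting how often the file is touched.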
Training loop
- Accumulate gradients
- Force training for min or max epochs
- Early stopping callback
- Force-disable early stopping
- Gradient clipping
- Hooks
- Learning rate scheduling
- Use multiple optimizers (like GANs)
- Set how much of the training set to check (1-100%)
- Step optimizers at arbitrary intervals
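Three of the training-loop options above, gradient accumulation, gradient clipping, and early stopping, fit in a few lines of plain Python. This is a sketch under stated assumptions: parameters and gradients are lists of floats, and `train_epoch`, `clip_grad_norm`, and `EarlyStopping` here are hypothetical helpers rather than Lightning's API.

```python
import math
from typing import Callable, List, Optional, Sequence


def clip_grad_norm(grads: List[float], max_norm: float) -> float:
    """Scale grads in place so their L2 norm is at most max_norm; return the original norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > max_norm:
        scale = max_norm / norm
        for i in range(len(grads)):
            grads[i] *= scale
    return norm


def train_epoch(params: List[float], batches: Sequence,
                grad_fn: Callable[[List[float], object], List[float]],
                lr: float, accumulate_grad_batches: int = 1,
                max_norm: Optional[float] = None) -> int:
    """One epoch of SGD with gradient accumulation and optional clipping.

    Returns the number of optimizer steps taken.
    """
    acc = [0.0] * len(params)
    steps = 0
    for i, batch in enumerate(batches, start=1):
        for j, g in enumerate(grad_fn(params, batch)):
            acc[j] += g / accumulate_grad_batches  # scaled as if averaging a full window
        # Step once the window is full, or at the end of the epoch.
        if i % accumulate_grad_batches == 0 or i == len(batches):
            if max_norm is not None:
                clip_grad_norm(acc, max_norm)
            for j in range(len(params)):
                params[j] -= lr * acc[j]
            acc = [0.0] * len(params)
            steps += 1
    return steps


class EarlyStopping:
    """Signal a stop when the monitored value has not improved for `patience` checks."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience, self.min_delta = patience, min_delta
        self.best, self.wait = math.inf, 0

    def should_stop(self, value: float) -> bool:
        if value < self.best - self.min_delta:
            self.best, self.wait = value, 0  # improvement resets the counter
        else:
            self.wait += 1
        return self.wait >= self.patience
```

With `accumulate_grad_batches=k` the optimizer steps once per k batches, which approximates a k-times larger batch without the extra memory.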
Validation loop
- Check validation every n epochs
- Hooks
- Set how much of the validation set to check
- Set how much of the test set to check
- Set validation check frequency within 1 training epoch
- Set the number of validation sanity steps
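The validation-loop options above are mostly scheduling arithmetic. A sketch of that arithmetic follows; the function names are hypothetical, and the parameter names echo the options in the list but are assumptions, so check the Trainer documentation for the exact arguments in your version.

```python
from itertools import islice
from typing import Callable, Iterable, List


def batches_to_check(num_batches: int, percent_check: float) -> int:
    """How many batches a fractional check covers (at least one if percent > 0)."""
    if not 0.0 <= percent_check <= 1.0:
        raise ValueError("percent_check must be in [0, 1]")
    if percent_check == 0.0:
        return 0
    return max(1, int(num_batches * percent_check))


def validation_batch_indices(num_train_batches: int, val_check_interval: float) -> List[int]:
    """0-based training-batch indices after which validation runs within one epoch.

    1.0 validates once at the end of the epoch; 0.25 validates four times.
    """
    if not 0.0 < val_check_interval <= 1.0:
        raise ValueError("val_check_interval must be in (0, 1]")
    every = max(1, int(num_train_batches * val_check_interval))
    return [i for i in range(num_train_batches) if (i + 1) % every == 0]


def validate_this_epoch(epoch: int, check_val_every_n_epoch: int) -> bool:
    """True for 0-based epochs n-1, 2n-1, ..., i.e. every n-th epoch."""
    return (epoch + 1) % check_val_every_n_epoch == 0


def run_sanity_check(validation_step: Callable, val_batches: Iterable,
                     num_steps: int = 2) -> list:
    """Run the first few validation batches before training to surface bugs early."""
    return [validation_step(b) for b in islice(val_batches, num_steps)]
```

A sanity check catches a crashing validation step in seconds instead of after a full training epoch.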
Testing loop