Trainer
The Lightning Trainer abstracts away best practices for running a training, validation, and test routine. It calls the relevant parts of your model whenever it hands full control to you, and otherwise applies training conventions that are now standard practice in AI research.
This is the basic use of the trainer:
```python
from pytorch_lightning import Trainer

# LightningTemplate stands in for your own LightningModule subclass
model = LightningTemplate()
trainer = Trainer()
trainer.fit(model)
```
But of course the fun is in all the advanced things it can do:
Checkpointing
- Model saving
- Model loading
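The Trainer handles checkpointing for you; the snippet below is only a framework-free sketch of the underlying idea (persist enough state to resume training, and write it atomically so a crash never leaves a half-written file). It is not Lightning's implementation: `save_checkpoint`, `load_checkpoint`, and the key names are hypothetical, and JSON stands in for whatever serializer a real trainer uses.

```python
import json
import os
import tempfile


def save_checkpoint(path, epoch, weights, optimizer_state):
    """Write training state to disk atomically (hypothetical helper)."""
    state = {"epoch": epoch, "weights": weights, "optimizer_state": optimizer_state}
    # Write to a temporary file in the same directory, then rename it over the
    # target, so readers only ever see a complete checkpoint.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w") as f:
        json.dump(state, f)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    """Read training state back so training can resume (hypothetical helper)."""
    with open(path) as f:
        return json.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        ckpt = os.path.join(d, "last.ckpt.json")
        save_checkpoint(ckpt, epoch=3, weights={"w": [0.5, -1.25]}, optimizer_state={"lr": 0.01})
        restored = load_checkpoint(ckpt)
        print(restored["epoch"], restored["weights"]["w"])  # 3 [0.5, -1.25]
```

Storing the epoch and optimizer state alongside the weights is what makes a checkpoint resumable rather than just a saved model.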
Computing cluster (SLURM)
Debugging
- Fast dev run
- Inspect gradient norms
- Log GPU usage
- Make model overfit on subset of data
- Print the parameter count by layer
- Print which gradients are NaN
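Several of these debugging switches boil down to simple checks. A framework-free sketch, assuming gradients and parameters are plain lists of floats keyed by layer name (a real trainer reads them from tensors); all function names here are hypothetical and are not Trainer arguments.

```python
import math


def grad_norm(grads, norm_type=2.0):
    """Total p-norm over every gradient value (grads: layer name -> list of floats)."""
    total = sum(abs(g) ** norm_type for values in grads.values() for g in values)
    return total ** (1.0 / norm_type)


def nan_grad_layers(grads):
    """Names of layers whose gradients contain at least one NaN."""
    return [name for name, values in grads.items() if any(math.isnan(g) for g in values)]


def param_count_by_layer(params):
    """Number of scalar parameters per layer, and the overall total."""
    counts = {name: len(values) for name, values in params.items()}
    return counts, sum(counts.values())


def overfit_subset(dataset, fraction):
    """A small, fixed slice of the data (at least one item).

    A healthy model should be able to drive its training loss close to zero
    on this slice; if it cannot, the model or loss likely has a bug.
    """
    n = max(1, int(len(dataset) * fraction))
    return dataset[:n]


grads = {"fc1": [3.0, 4.0], "fc2": [float("nan")]}
print(nan_grad_layers(grads))  # ['fc2']
```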
Distributed training
Experiment Logging
- Display metrics in progress bar
- Log metric row every k batches
- Process position
- TensorBoard support
- Save a snapshot of all hyperparameters
- Snapshot code for a training run
- Write the log file to CSV every k batches
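The "log a row every k batches" and "write to CSV every k batches" options describe two separate intervals: how often a metric row is recorded and how often buffered rows are flushed to disk. A minimal sketch of that split, plus a one-off hyperparameter snapshot; `CsvMetricLogger` and its parameters are hypothetical and not Lightning's logger API.

```python
import csv
import json
import os


class CsvMetricLogger:
    """Hypothetical logger: keeps one metric row every `row_every` batches and
    appends the buffered rows to a CSV file every `save_every` batches."""

    def __init__(self, log_dir, row_every=10, save_every=100):
        os.makedirs(log_dir, exist_ok=True)
        self.log_dir = log_dir
        self.csv_path = os.path.join(log_dir, "metrics.csv")
        self.row_every = row_every
        self.save_every = save_every
        self.buffer = []

    def log_hparams(self, hparams):
        # Snapshot the hyperparameters once so the run can be reproduced later.
        with open(os.path.join(self.log_dir, "hparams.json"), "w") as f:
            json.dump(hparams, f, indent=2, sort_keys=True)

    def log_metrics(self, batch_idx, metrics):
        # batch_idx is 0-based, so batch_idx + 1 counts batches seen so far.
        seen = batch_idx + 1
        if seen % self.row_every == 0:
            self.buffer.append({"batch": batch_idx, **metrics})
        if seen % self.save_every == 0:
            self.flush()

    def flush(self):
        # Append buffered rows; write the header only when the file is new.
        # Assumes every row carries the same metric names.
        if not self.buffer:
            return
        write_header = not os.path.exists(self.csv_path)
        with open(self.csv_path, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=list(self.buffer[0]))
            if write_header:
                writer.writeheader()
            writer.writerows(self.buffer)
        self.buffer.clear()
```

Call `flush()` once more when training ends so rows recorded after the last save interval are not lost. Buffering keeps disk writes infrequent while still sampling metrics at a finer granularity.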
Training loop
- Accumulate gradients
- Anneal learning rate
- Force training for min or max epochs
- Force-disable early stopping
- Gradient Clipping
- Hooks
- Use multiple optimizers (like GANs)
- Set how much of the training set to check (1-100%)
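To make the training-loop options concrete, here is a toy, framework-free loop that fits y = w·x by mini-batch gradient descent while combining gradient accumulation, gradient clipping by norm, step-decay learning-rate annealing, and early stopping bounded by minimum and maximum epoch counts. It is a sketch under simplifying assumptions (one scalar weight, early stopping on the training loss rather than a validation metric), and every parameter name is hypothetical rather than a Trainer argument.

```python
def train(data, lr=0.1, batch_size=2, accumulate=2, clip_norm=5.0,
          min_epochs=2, max_epochs=50, patience=3, decay=0.9, decay_every=5):
    """Fit y = w * x with mini-batch gradient descent (toy example).

    Returns the learned weight and the number of epochs actually run.
    """
    w = 0.0
    best_loss, bad_epochs, epochs_run = float("inf"), 0, 0
    batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
    for epoch in range(max_epochs):
        epochs_run += 1
        # Step-decay annealing: scale the learning rate by `decay` every `decay_every` epochs.
        epoch_lr = lr * decay ** (epoch // decay_every)
        grad_sum, n_accumulated, epoch_loss = 0.0, 0, 0.0
        for step, batch in enumerate(batches, start=1):
            # Mean squared error of the batch and its derivative with respect to w.
            epoch_loss += sum((w * x - y) ** 2 for x, y in batch) / len(batch)
            grad_sum += sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
            n_accumulated += 1
            # Accumulate gradients: update only every `accumulate` batches
            # (or at the end of the epoch), using the averaged gradient.
            if n_accumulated == accumulate or step == len(batches):
                grad = grad_sum / n_accumulated
                # Gradient clipping by norm; for a single scalar the norm is |grad|.
                if abs(grad) > clip_norm:
                    grad = clip_norm if grad > 0 else -clip_norm
                w -= epoch_lr * grad
                grad_sum, n_accumulated = 0.0, 0
        epoch_loss /= len(batches)
        # Early stopping on the training loss, but never before `min_epochs`.
        if epoch_loss < best_loss - 1e-9:
            best_loss, bad_epochs = epoch_loss, 0
        else:
            bad_epochs += 1
        if epochs_run >= min_epochs and bad_epochs >= patience:
            break
    return w, epochs_run


data = [(0.5, 1.0), (1.0, 2.0), (1.5, 3.0), (2.0, 4.0)]  # y = 2x
w, epochs = train(data)
print(f"w={w:.4f} after {epochs} epochs")
```

In this sketch, passing `patience=float("inf")` disables early stopping entirely, so the loop always runs for `max_epochs`; `min_epochs` only delays when early stopping is allowed to fire.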
Validation loop