New project Quick Start
To start a new project, define two files: a LightningModule and a Trainer file.
To illustrate Lightning's power and simplicity, here's an example of a typical research flow.
Case 1: BERT
Let's say you're working on something like BERT but want to try different ways of training or even different networks.
You would define a single LightningModule and use flags to switch between your different ideas.
```python
import pytorch_lightning as pl

class BERT(pl.LightningModule):
    def __init__(self, model_name, task):
        super().__init__()
        self.task = task
        # Transformer and MyCoolVersion are placeholders for your own networks
        if model_name == 'transformer':
            self.net = Transformer()
        elif model_name == 'my_cool_version':
            self.net = MyCoolVersion()

    def training_step(self, batch, batch_nb):
        if self.task == 'standard_bert':
            # do standard BERT training with self.net...
            loss = ...
        elif self.task == 'my_cool_task':
            # do your own version with self.net...
            loss = ...
        return loss
```
Case 2: COOLER NOT BERT
But if you wanted to try something completely different, you'd define a new module for that.
```python
class CoolerNotBERT(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = ...  # your new network

    def training_step(self, batch, batch_nb):
        # do some other cool task...
        loss = ...
        return loss
```
Rapid research flow
Then you could do rapid research by switching between these two and using the same trainer.
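```python
from pytorch_lightning import Trainer

if use_bert:  # your own flag, e.g. read from the command line
    model = BERT(model_name='transformer', task='standard_bert')
else:
    model = CoolerNotBERT()

trainer = Trainer(gpus=4, use_amp=True)
trainer.fit(model)
```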
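The flags that switch between ideas can come straight from the command line. The following is a minimal sketch using the standard library's `argparse`; the flag names `--model`, `--model_name`, and `--task` and the function `select_model` are hypothetical. So that the snippet runs without Lightning installed, it returns the class name and constructor arguments instead of building the model.

```python
import argparse

def build_parser():
    # Hypothetical flags mirroring the choices in the BERT example
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", choices=["bert", "cooler_not_bert"], default="bert")
    parser.add_argument("--model_name", choices=["transformer", "my_cool_version"],
                        default="transformer")
    parser.add_argument("--task", choices=["standard_bert", "my_cool_task"],
                        default="standard_bert")
    return parser

def select_model(argv=None):
    """Return (class name, constructor kwargs) chosen by the flags."""
    args = build_parser().parse_args(argv)
    if args.model == "bert":
        # In a real script: model = BERT(model_name=args.model_name, task=args.task)
        return "BERT", {"model_name": args.model_name, "task": args.task}
    # In a real script: model = CoolerNotBERT()
    return "CoolerNotBERT", {}

print(select_model(["--model", "cooler_not_bert"]))  # ('CoolerNotBERT', {})
```

Because the Trainer call never changes, switching experiments comes down to changing a flag on the command line.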
Notice a few things about this flow:
- You're writing pure PyTorch: no unnecessary abstractions or new libraries to learn.
- You get free GPU and 16-bit support without writing any of that code in your model.
- You also get all of the capabilities below without having to code or test them yourself.
Templates
Docs shortcuts
Quick start examples
- CPU example
- Hyperparameter search on single GPU
- Hyperparameter search on multiple GPUs on same node
- [Hyperparameter search on a SLURM HPC cluster](<examples/Examples/#Hyperparameter search on a SLURM HPC cluster>)
Checkpointing
Computing cluster (SLURM)
Debugging
- Fast dev run
- Inspect gradient norms
- Log GPU usage
- Make model overfit on subset of data
- Print the parameter count by layer
- Print which gradients are NaN
- Print the input and output size of every module in the system
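Inspecting gradient norms and finding NaN gradients both come down to simple per-parameter reductions. Below is a minimal pure-Python sketch of the idea. It is not Lightning's implementation: each gradient is assumed to be a plain list of floats keyed by a hypothetical parameter name, standing in for a parameter's `.grad` tensor.

```python
import math

def grad_norms(named_grads, norm_type=2.0):
    """Return the p-norm of each parameter's gradient plus the total norm."""
    norms = {}
    for name, grad in named_grads.items():
        norms[name] = sum(abs(g) ** norm_type for g in grad) ** (1.0 / norm_type)
    # The total norm equals the norm of all gradients concatenated into one vector
    norms["total"] = sum(n ** norm_type for n in norms.values()) ** (1.0 / norm_type)
    return norms

def nan_grads(named_grads):
    """Return the names of parameters whose gradient contains a NaN."""
    return [name for name, grad in named_grads.items()
            if any(math.isnan(g) for g in grad)]

# Hypothetical parameter names and gradient values
grads = {"encoder.weight": [3.0, 4.0], "decoder.bias": [0.0, float("nan")]}
print(nan_grads(grads))  # ['decoder.bias']

clean = {"encoder.weight": [3.0, 4.0], "decoder.bias": [12.0]}
norms = grad_norms(clean)
print(round(norms["total"], 6))  # 13.0
```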
Distributed training
Experiment Logging
- Display metrics in progress bar
- Log metric row every k batches
- Process position
- Tensorboard support
- Save a snapshot of all hyperparameters
- Snapshot code for a training run
- Write the log file to CSV every k batches
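Logging a metric row every k batches and writing it to CSV reduces to a modulo check on the batch index. This is a stdlib-only sketch of that idea. The function name, the `batch_idx` column, and the per-batch metric dicts are hypothetical and are not Lightning's logger API.

```python
import csv
import io

def log_metrics_every_k(metric_rows, k, stream):
    """Write one CSV row for every k-th batch; return the number of rows written."""
    writer = None
    written = 0
    for batch_idx, metrics in enumerate(metric_rows):
        if batch_idx % k != 0:
            continue  # skip batches between logging points
        row = {"batch_idx": batch_idx, **metrics}
        if writer is None:
            # Derive the CSV header from the first logged row
            writer = csv.DictWriter(stream, fieldnames=list(row))
            writer.writeheader()
        writer.writerow(row)
        written += 1
    return written

buf = io.StringIO()
losses = [{"loss": float(10 - i)} for i in range(10)]  # made-up per-batch losses
print(log_metrics_every_k(losses, k=4, stream=buf))  # 3 (batches 0, 4 and 8)
print(buf.getvalue())
```

Swapping `io.StringIO` for an opened file (with `newline=""`, as the `csv` module recommends) writes the same rows to disk.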
Training loop
- Accumulate gradients
- Force training for min or max epochs
- Early stopping callback
- Force disable early stop
- Gradient clipping
- Hooks
- Learning rate scheduling
- Use multiple optimizers (like GANs)
- Set how much of the training set to check (1-100%)
- Step optimizers at arbitrary intervals
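Early stopping and the min/max epoch limits interact in a simple way. The following is a minimal pure-Python sketch of patience-based early stopping. `PatienceStopper`, `patience`, `min_delta`, and `run_epochs` are hypothetical names chosen for illustration, not Lightning's callback signature. A list of validation losses stands in for a real validation loop, and its length plays the role of a maximum epoch count.

```python
class PatienceStopper:
    """Stop when a monitored metric (lower is better) stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # epochs to tolerate without improvement
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, value):
        """Record one epoch's metric; return True once patience is exhausted."""
        if value < self.best - self.min_delta:
            self.best = value  # new best: reset the counter
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


def run_epochs(val_losses, patience, min_epochs=0):
    """Simulate an epoch loop and return how many epochs were run."""
    stopper = PatienceStopper(patience=patience)
    for epoch, loss in enumerate(val_losses, start=1):
        # should_stop runs first every epoch, so each loss is always recorded;
        # min_epochs keeps training going even after the stopper fires
        if stopper.should_stop(loss) and epoch >= min_epochs:
            return epoch
    return len(val_losses)


losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69]
print(run_epochs(losses, patience=3))                # 6
print(run_epochs(losses, patience=3, min_epochs=7))  # 7
```

A very large `patience` effectively disables early stopping, so training runs until the epoch limit.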
Validation loop
- Check validation every n epochs
- Hooks
- Set how much of the validation set to check
- Set how much of the test set to check
- Set validation check frequency within 1 training epoch
- Set the number of validation sanity steps
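The validation-loop options above come down to batch arithmetic. This sketch uses hypothetical function and parameter names that loosely echo those options. It illustrates the idea under its own stated rules and does not reproduce Lightning's exact semantics. It computes where validation would run and how many batches a fractional check covers.

```python
def validation_points(num_train_batches, num_epochs, val_check_interval=1.0,
                      check_every_n_epochs=1):
    """Return (epoch, batch_idx) pairs after which validation would run.

    val_check_interval is a fraction of an epoch (0.25 validates four times
    per epoch). Epochs are 0-indexed; validation only happens on epochs where
    (epoch + 1) is a multiple of check_every_n_epochs.
    """
    # Number of training batches between validation runs (at least 1)
    step = max(1, int(num_train_batches * val_check_interval))
    points = []
    for epoch in range(num_epochs):
        if (epoch + 1) % check_every_n_epochs != 0:
            continue
        for batch_idx in range(num_train_batches):
            if (batch_idx + 1) % step == 0:
                points.append((epoch, batch_idx))
    return points


def limit_batches(num_batches, percent_check):
    """Number of batches to run when checking only a fraction of a dataset."""
    return max(1, int(num_batches * percent_check))


print(validation_points(100, 4, check_every_n_epochs=2))  # [(1, 99), (3, 99)]
print(limit_batches(50, 0.2))  # 10
```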