22 KiB
Lightning Module interface
A lightning module is a strict superclass of nn.Module, it provides a standard interface for the trainer to interact with the model.
The easiest thing to do is copy the minimal example below and modify accordingly.
Otherwise, to Define a Lightning Module, implement the following methods:
Required:
Optional:
- validation_step
- validation_end
- test_step
- test_end
- val_dataloader
- test_dataloader
- on_save_checkpoint
- on_load_checkpoint
- add_model_specific_args
Minimal example
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
import pytorch_lightning as pl
class CoolModel(pl.LightningModule):
def __init__(self):
super(CoolModel, self).__init__()
# not the best model...
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
# REQUIRED
x, y = batch
y_hat = self.forward(x)
return {'loss': F.cross_entropy(y_hat, y)}
def validation_step(self, batch, batch_nb):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
return {'val_loss': F.cross_entropy(y_hat, y)}
def validation_end(self, outputs):
# OPTIONAL
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
return {'avg_val_loss': avg_loss}
def test_step(self, batch, batch_nb):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
return {'test_loss': F.cross_entropy(y_hat, y)}
def test_end(self, outputs):
# OPTIONAL
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
return {'avg_test_loss': avg_loss}
def configure_optimizers(self):
# REQUIRED
return torch.optim.Adam(self.parameters(), lr=0.02)
@pl.data_loader
def train_dataloader(self):
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@pl.data_loader
def val_dataloader(self):
# OPTIONAL
# can also return a list of val dataloaders
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@pl.data_loader
def test_dataloader(self):
# OPTIONAL
# can also return a list of test dataloaders
return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
How do these methods fit into the broader training?
The LightningModule interface is on the right. Each method corresponds to a part of a research project. Lightning automates everything not in blue.
Required Methods
training_step
def training_step(self, batch, batch_nb)
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model.
Params
Param | description |
---|---|
batch | The output of your dataloader. A tensor, tuple or list |
batch_nb | Integer displaying which batch this is |
Return
Dictionary or OrderedDict
key | value | is required |
---|---|---|
loss | tensor scalar | Y |
progress_bar | Dict for progress bar display. Must have only tensors | N |
log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N |
Example
def training_step(self, batch, batch_nb):
x, y, z = batch
# implement your own
out = self.forward(x)
loss = self.loss(out, x)
output = {
'loss': loss, # required
'progress_bar': {'training_loss': loss}, # optional (MUST ALL BE TENSORS)
'log': {'training_loss': loss} # optional (MUST ALL BE TENSORS)
}
# return a dict
return output
If you define multiple optimizers, this step will also be called with an additional optimizer_idx
param.
# Multiple optimizers (ie: GANs)
def training_step(self, batch, batch_nb, optimizer_idx):
if optimizer_idx == 0:
# do training_step with encoder
if optimizer_idx == 1:
# do training_step with decoder
You can also return a -1 instead of a dict to stop the current loop. This is useful if you want to break out of the current training epoch early.
train_dataloader
@pl.data_loader
def train_dataloader(self)
Called by lightning during training loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.
If you want to change the data during every epoch DON'T use the data_loader decorator.
Return
PyTorch DataLoader
Example
@pl.data_loader
def train_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=True, transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
)
return loader
configure_optimizers
def configure_optimizers(self)
Set up as many optimizers and (optionally) learning rate schedulers as you need. Normally you'd need one. But in the case of GANs or something more esoteric you might have multiple. Lightning will call .backward() and .step() on each one in every epoch. If you use 16 bit precision it will also handle that.
Note: If you use multiple optimizers, training_step will have an additional optimizer_idx
parameter.
Note 2: If you use LBFGS lightning handles the closure function automatically for you.
Return
Return any of these 3 options:
Single optimizer
List or Tuple - List of optimizers
Two lists - The first list has multiple optimizers, the second a list of learning-rate schedulers
Example
# most cases
def configure_optimizers(self):
opt = Adam(self.parameters(), lr=0.01)
return opt
# multiple optimizer case (eg: GAN)
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
return generator_opt, disriminator_opt
# example with learning_rate schedulers
def configure_optimizers(self):
generator_opt = Adam(self.model_gen.parameters(), lr=0.01)
disriminator_opt = Adam(self.model_disc.parameters(), lr=0.02)
discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
return [generator_opt, disriminator_opt], [discriminator_sched]
If you need to control how often those optimizers step or override the default .step() schedule, override the optimizer_step hook.
Optional Methods
validation_step
# if you have one val dataloader:
def validation_step(self, batch, batch_nb)
# if you have multiple val dataloaders:
def validation_step(self, batch, batch_nb, dataloader_idxdx)
OPTIONAL
If you don't need to validate you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
When the validation_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
The dict you return here will be available in the validation_end
method.
Params
Param | description |
---|---|
batch | The output of your dataloader. A tensor, tuple or list |
batch_nb | Integer displaying which batch this is |
dataloader_idx | Integer displaying which dataloader this is (only if multiple val datasets used) |
Return
Return | description | optional |
---|---|---|
dict | Dict or OrderedDict - passed to the validation_end step | N |
Example
# CASE 1: A single validation dataset
def validation_step(self, batch, batch_nb):
x, y = batch
# implement your own
out = self.forward(x)
loss = self.loss(out, y)
# log 6 example images
# or generated text... or whatever
sample_imgs = x[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('example_images', grid, 0)
# calculate acc
labels_hat = torch.argmax(out, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
# all optional...
# return whatever you need for the collation function validation_end
output = OrderedDict({
'val_loss': loss_val,
'val_acc': torch.tensor(val_acc), # everything must be a tensor
})
# return an optional dict
return output
If you pass in multiple validation datasets, validation_step will have an additional argument.
# CASE 2: multiple validation datasets
def validation_step(self, batch, batch_nb, dataset_idx):
# dataset_idx tells you which dataset this is.
The dataset_idx
corresponds to the order of datasets returned in val_dataloader
.
validation_end
def validation_end(self, outputs)
If you didn't define a validation_step, this won't be called.
Called at the end of the validation loop with the outputs of validation_step.
The outputs here are strictly for the progress bar. If you don't need to display anything, don't return anything.
Params
Param | description |
---|---|
outputs | List of outputs you defined in validation_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader |
Return
Dictionary or OrderedDict
key | value | is required |
---|---|---|
progress_bar | Dict for progress bar display. Must have only tensors | N |
log | Dict of metrics to add to logger. Must have only tensors (no images, etc) | N |
Example
With a single dataloader
def validation_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of individual outputs of each validation step
:return:
"""
val_loss_mean = 0
val_acc_mean = 0
for output in outputs:
val_loss_mean += output['val_loss']
val_acc_mean += output['val_acc']
val_loss_mean /= len(outputs)
val_acc_mean /= len(outputs)
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item()}
}
return results
With multiple dataloaders, outputs
will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
def validation_end(self, outputs):
"""
Called at the end of validation to aggregate outputs
:param outputs: list of list of individual outputs of each validation step
:return:
"""
val_loss_mean = 0
val_acc_mean = 0
i = 0
for dataloader_outputs in outputs:
for output in dataloader_outputs:
val_loss_mean += output['val_loss']
val_acc_mean += output['val_acc']
i += 1
val_loss_mean /= i
val_acc_mean /= i
tqdm_dict = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
# show val_loss and val_acc in progress bar but only log val_loss
results = {
'progress_bar': tqdm_dict,
'log': {'val_loss': val_loss_mean.item()}
}
return results
test_step
# if you have one test dataloader:
def test_step(self, batch, batch_nb)
# if you have multiple test dataloaders:
def test_step(self, batch, batch_nb, dataloader_idxdx)
OPTIONAL
If you don't need to test you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
When the validation_step is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, model goes back to training mode and gradients are enabled.
The dict you return here will be available in the test_end
method.
This function is used when you execute trainer.test()
.
Params
Param | description |
---|---|
batch | The output of your dataloader. A tensor, tuple or list |
batch_nb | Integer displaying which batch this is |
dataloader_idx | Integer displaying which dataloader this is (only if multiple test datasets used) |
Return
Return | description | optional |
---|---|---|
dict | Dict or OrderedDict with metrics to display in progress bar. All keys must be tensors. | Y |
Example
# CASE 1: A single test dataset
def test_step(self, batch, batch_nb):
x, y = batch
# implement your own
out = self.forward(x)
loss = self.loss(out, y)
# calculate acc
labels_hat = torch.argmax(out, dim=1)
test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
# all optional...
# return whatever you need for the collation function test_end
output = OrderedDict({
'test_loss': loss_test,
'test_acc': torch.tensor(test_acc), # everything must be a tensor
})
# return an optional dict
return output
If you pass in multiple test datasets, test_step will have an additional argument.
# CASE 2: multiple test datasets
def test_step(self, batch, batch_nb, dataset_idx):
# dataset_idx tells you which dataset this is.
The dataset_idx
corresponds to the order of datasets returned in test_dataloader
.
test_end
def test_end(self, outputs)
If you didn't define a test_step, this won't be called.
Called at the end of the test step with the output of each test_step.
The outputs here are strictly for the progress bar. If you don't need to display anything, don't return anything.
Params
Param | description |
---|---|
outputs | List of outputs you defined in test_step, or if there are multiple dataloaders, a list containing a list of outputs for each dataloader |
Return
Return | description | optional |
---|---|---|
dict | Dict of OrderedDict with metrics to display in progress bar | Y |
Example
def test_end(self, outputs):
"""
Called at the end of test to aggregate outputs
:param outputs: list of individual outputs of each test step
:return:
"""
test_loss_mean = 0
test_acc_mean = 0
for output in outputs:
test_loss_mean += output['test_loss']
test_acc_mean += output['test_acc']
test_loss_mean /= len(outputs)
test_acc_mean /= len(outputs)
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': val_loss_mean.item()}
}
return results
With multiple dataloaders, outputs
will be a list of lists. The outer list contains
one entry per dataloader, while the inner list contains the individual outputs of
each validation step for that dataloader.
def test_end(self, outputs):
"""
Called at the end of test to aggregate outputs
:param outputs: list of individual outputs of each test step
:return:
"""
test_loss_mean = 0
test_acc_mean = 0
i = 0
for dataloader_outputs in outputs:
for output in dataloader_outputs:
test_loss_mean += output['test_loss']
test_acc_mean += output['test_acc']
i += 1
test_loss_mean /= i
test_acc_mean /= i
tqdm_dict = {'test_loss': test_loss_mean.item(), 'test_acc': test_acc_mean.item()}
# show test_loss and test_acc in progress bar but only log test_loss
results = {
'progress_bar': tqdm_dict,
'log': {'test_loss': val_loss_mean.item()}
}
return results
on_save_checkpoint
def on_save_checkpoint(self, checkpoint)
Called by lightning to checkpoint your model. Lightning saves the training state (current epoch, global_step, etc) and also saves the model state_dict. If you want to save anything else, use this method to add your own key-value pair.
Return
Nothing
Example
def on_save_checkpoint(self, checkpoint):
# 99% of use cases you don't need to implement this method
checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
on_load_checkpoint
def on_load_checkpoint(self, checkpoint)
Called by lightning to restore your model. Lighting auto-restores global step, epoch, etc... It also restores the model state_dict. If you saved something with on_save_checkpoint this is your chance to restore this.
Return
Nothing
Example
def on_load_checkpoint(self, checkpoint):
# 99% of the time you don't need to implement this method
self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
val_dataloader
@pl.data_loader
def val_dataloader(self)
OPTIONAL
If you don't need a validation dataset and a validation_step, you don't need to implement this method.
Called by lightning during validation loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.
If you want to change the data during every epoch DON'T use the data_loader decorator.
Return
PyTorch DataLoader or list of PyTorch Dataloaders.
Example
@pl.data_loader
def val_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
)
return loader
# can also return multiple dataloaders
@pl.data_loader
def val_dataloader(self):
return [loader_a, loader_b, ..., loader_n]
In the case where you return multiple val_dataloaders, the validation_step will have an arguement dataset_idx
which matches the order here.
test_dataloader
@pl.data_loader
def test_dataloader(self)
OPTIONAL
If you don't need a test dataset and a test_step, you don't need to implement this method.
Called by lightning during test loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed. If you want to change the data during every epoch DON'T use the data_loader decorator.
Return
PyTorch DataLoader
Example
@pl.data_loader
def test_dataloader(self):
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root='/path/to/mnist/', train=False, transform=transform, download=True)
loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.hparams.batch_size,
shuffle=True
)
return loader
add_model_specific_args
@staticmethod
def add_model_specific_args(parent_parser, root_dir)
Lightning has a list of default argparse commands. This method is your chance to add or modify commands specific to your model. The hyperparameter argument parser is available anywhere in your model by calling self.hparams.
Return
An argument parser
Example
@staticmethod
def add_model_specific_args(parent_parser, root_dir):
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# param overwrites
# parser.set_defaults(gradient_clip_val=5.0)
# network params
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
parser.add_argument('--in_features', default=28*28)
parser.add_argument('--out_features', default=10)
parser.add_argument('--hidden_dim', default=50000) # use 500 for CPU, 50000 for GPU to see speed difference
# data
parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)
# training params (opt)
parser.opt_list('--learning_rate', default=0.001, type=float, options=[0.0001, 0.0005, 0.001, 0.005],
tunable=False)
parser.opt_list('--batch_size', default=256, type=int, options=[32, 64, 128, 256], tunable=False)
parser.opt_list('--optimizer_name', default='adam', type=str, options=['adam'], tunable=False)
return parser