# Source: lightning/pl_examples/basic_examples/lightning_module_template.py
# (file-viewer metadata from the original listing: 283 lines, 9.1 KiB, Python)

"""
Example template for defining a system.
"""
import os
from argparse import ArgumentParser
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from pytorch_lightning import _logger as log
from pytorch_lightning.core import LightningModule
class LightningTemplateModel(LightningModule):
    """
    Sample model to show how to define a template.

    The network is ``Linear -> tanh -> BatchNorm1d -> Dropout -> Linear -> log_softmax``
    and is trained on MNIST images that are flattened to one vector per sample.

    Example:

        >>> # define simple Net for MNIST dataset
        >>> params = dict(
        ...     drop_prob=0.2,
        ...     batch_size=2,
        ...     in_features=28 * 28,
        ...     learning_rate=0.001 * 8,
        ...     optimizer_name='adam',
        ...     data_root='./datasets',
        ...     out_features=10,
        ...     hidden_dim=1000,
        ... )
        >>> from argparse import Namespace
        >>> hparams = Namespace(**params)
        >>> model = LightningTemplateModel(hparams)
    """

    def __init__(self, hparams):
        """
        Pass in hyperparameters as an `argparse.Namespace`.

        Any object exposing the hyperparameters as attributes works; a plain `dict`
        does not, because the values are read with attribute access below.

        :param hparams: object with the attributes `batch_size`, `in_features`,
            `hidden_dim`, `out_features`, `drop_prob`, `learning_rate` and `data_root`.
        """
        # init superclass
        super().__init__()
        self.hparams = hparams

        self.batch_size = hparams.batch_size

        # if you specify an example input, the summary will show input/output for each layer.
        # The width follows `in_features` so the example always fits the first layer.
        self.example_input_array = torch.rand(5, hparams.in_features)

        # build model
        self.__build_model()

    # ---------------------
    # MODEL SETUP
    # ---------------------
    def __build_model(self):
        """
        Layout the model: one hidden linear layer (with batch norm and dropout)
        followed by the output linear layer.
        """
        self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
                              out_features=self.hparams.hidden_dim)
        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)

        self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
                              out_features=self.hparams.out_features)

    # ---------------------
    # TRAINING
    # ---------------------
    def forward(self, x):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.

        :param x: input that the first linear layer accepts (last dim `in_features`).
        :return: log-probabilities over the output classes (`log_softmax` along dim 1).
        """
        x = self.c_d1(x)
        x = torch.tanh(x)
        x = self.c_d1_bn(x)
        x = self.c_d1_drop(x)

        x = self.c_d2(x)
        # despite the name, these are log-probabilities, which is what `nll_loss` expects
        logits = F.log_softmax(x, dim=1)

        return logits

    def loss(self, labels, logits):
        """
        Negative log-likelihood of `labels` under the output of `forward`.

        :param labels: target class indices.
        :param logits: log-probabilities as returned by `forward`.
        :return: the loss as computed by `F.nll_loss`.
        """
        nll = F.nll_loss(logits, labels)
        return nll

    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.

        :param batch: pair `(x, y)` of inputs and targets.
        :param batch_idx: index of the batch (unused here).
        :return: dict with the `loss` plus the `progress_bar` and `log` entries.
        """
        # forward pass
        x, y = batch
        # flatten every sample to a single vector
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        # calculate loss
        loss_val = self.loss(y, y_hat)

        tqdm_dict = {'train_loss': loss_val}
        output = OrderedDict({
            'loss': loss_val,
            'progress_bar': tqdm_dict,
            'log': tqdm_dict
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.

        :param batch: pair `(x, y)` of inputs and targets.
        :param batch_idx: index of the batch (unused here).
        :return: dict with `val_loss` and `val_acc` for this batch.
        """
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_val = self.loss(y, y_hat)

        # acc: fraction of correct predictions, computed as a tensor so it stays on the
        # same device as the batch without any explicit host/device transfer
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = (labels_hat == y).float().mean()

        output = OrderedDict({
            'val_loss': loss_val,
            'val_acc': val_acc,
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def _aggregate_epoch_outputs(self, outputs, prefix):
        """
        Average the per-step loss and accuracy of one evaluation epoch.

        :param outputs: list of step outputs, each holding `<prefix>_loss` and `<prefix>_acc`.
        :param prefix: key prefix of the step outputs, `'val'` or `'test'`.
        :return: dict with `progress_bar` and `log` (the same metrics dict) and `<prefix>_loss`.
        :raises ZeroDivisionError: if `outputs` is empty.
        """
        loss_key = prefix + '_loss'
        acc_key = prefix + '_acc'

        # reduce manually when using dp: the step values are averaged first in these modes
        reduce_per_step = self.trainer.use_dp or self.trainer.use_ddp2

        loss_mean = 0
        acc_mean = 0
        for output in outputs:
            step_loss = output[loss_key]
            step_acc = output[acc_key]
            if reduce_per_step:
                step_loss = torch.mean(step_loss)
                step_acc = torch.mean(step_acc)
            loss_mean += step_loss
            acc_mean += step_acc

        loss_mean /= len(outputs)
        acc_mean /= len(outputs)

        tqdm_dict = {loss_key: loss_mean, acc_key: acc_mean}
        return {'progress_bar': tqdm_dict, 'log': tqdm_dict, loss_key: loss_mean}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.

        :param outputs: list of individual outputs of each validation step.
        :return: dict with `progress_bar`, `log` and the mean `val_loss`.
        """
        # if returned a scalar from validation_step, outputs is a list of tensor scalars
        # we return just the average in this case (if we want)
        # return torch.stack(outputs).mean()
        return self._aggregate_epoch_outputs(outputs, 'val')

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        """
        Return whatever optimizers and learning rate schedulers you want here.
        At least one optimizer is required.

        :return: a list with the Adam optimizer and a list with its cosine annealing scheduler.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]

    def __dataloader(self, train):
        """
        Build a `DataLoader` over the MNIST split selected by `train`.

        :param train: `True` for the training split, `False` for the test split.
        :return: the data loader, batched by `hparams.batch_size`.
        """
        # this is needed when you want some info about dataset before binding to trainer
        self.prepare_data()

        # init data generators
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root=self.hparams.data_root, train=train,
                        transform=transform, download=False)

        # no sampler is configured here
        # NOTE(review): presumably the trainer attaches a distributed sampler under ddp - confirm
        batch_size = self.hparams.batch_size

        loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=0
        )

        return loader

    def prepare_data(self):
        """
        Download MNIST into `hparams.data_root` so the loaders can open it with `download=False`.
        """
        # the dataset object is discarded, only the download matters, so no transform is needed
        _ = MNIST(root=self.hparams.data_root, train=True, download=True)

    def train_dataloader(self):
        """Return the loader over the MNIST training split."""
        log.info('Training data loader called.')
        return self.__dataloader(train=True)

    def val_dataloader(self):
        """Return the loader over the MNIST test split, used here for validation."""
        log.info('Validation data loader called.')
        return self.__dataloader(train=False)

    def test_dataloader(self):
        """Return the loader over the MNIST test split."""
        log.info('Test data loader called.')
        return self.__dataloader(train=False)

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this during testing, similar to `validation_step`,
        with the data from the test dataloader passed in as `batch`.

        :return: the `validation_step` output with its keys renamed to `test_loss` / `test_acc`.
        """
        output = self.validation_step(batch, batch_idx)
        # Rename output keys
        output['test_loss'] = output.pop('val_loss')
        output['test_acc'] = output.pop('val_acc')

        return output

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs, similar to `validation_epoch_end`.

        :param outputs: list of individual outputs of each test step
        :return: dict with `progress_bar`, `log` and the mean `test_loss`.
        """
        # the step outputs already carry `test_*` keys (see `test_step`), so they are
        # aggregated under that prefix directly instead of renaming validation results
        return self._aggregate_epoch_outputs(outputs, 'test')

    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):  # pragma: no-cover
        """
        Parameters you define here will be available to your model through `self.hparams`.

        :param parent_parser: parser whose arguments are inherited by the returned parser.
        :param root_dir: directory under which the default `mnist` data folder is placed.
        :return: a new `ArgumentParser` with the model arguments added.
        """
        parser = ArgumentParser(parents=[parent_parser])

        # param overwrites
        # parser.set_defaults(gradient_clip_val=5.0)

        # network params
        parser.add_argument('--in_features', default=28 * 28, type=int)
        parser.add_argument('--out_features', default=10, type=int)
        # use 500 for CPU, 50000 for GPU to see speed difference
        parser.add_argument('--hidden_dim', default=50000, type=int)
        parser.add_argument('--drop_prob', default=0.2, type=float)
        parser.add_argument('--learning_rate', default=0.001, type=float)

        # data
        parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)

        # training params (opt)
        # `epochs` and `optimizer_name` are not read by this class itself
        parser.add_argument('--epochs', default=20, type=int)
        parser.add_argument('--optimizer_name', default='adam', type=str)
        parser.add_argument('--batch_size', default=64, type=int)
        return parser