2019-08-04 18:19:23 +00:00
|
|
|
"""
|
|
|
|
Example template for defining a system
|
|
|
|
"""
|
2019-06-27 15:04:02 +00:00
|
|
|
import os
|
|
|
|
from collections import OrderedDict
|
|
|
|
import torch.nn as nn
|
|
|
|
from torchvision.datasets import MNIST
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from test_tube import HyperOptArgumentParser
|
|
|
|
from torch import optim
|
2019-07-08 22:02:41 +00:00
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
from torch.utils.data.distributed import DistributedSampler
|
2019-06-27 15:04:02 +00:00
|
|
|
|
2019-08-07 06:02:55 +00:00
|
|
|
import pytorch_lightning as pl
|
2019-06-27 15:04:02 +00:00
|
|
|
from pytorch_lightning.root_module.root_module import LightningModule
|
|
|
|
|
|
|
|
|
|
|
|
class LightningTemplateModel(LightningModule):
    """
    Sample model to show how to define a template.

    A two-layer fully-connected MNIST classifier:
    Linear -> tanh -> BatchNorm1d -> Dropout -> Linear -> log_softmax.
    """

    def __init__(self, hparams):
        """
        Pass in parsed HyperOptArgumentParser to the model.

        :param hparams: parsed hyperparameters (see ``add_model_specific_args``);
            this class reads ``batch_size``, ``in_features``, ``hidden_dim``,
            ``drop_prob``, ``out_features``, ``learning_rate`` and ``data_root``.
        """
        # init superclass
        super(LightningTemplateModel, self).__init__()
        self.hparams = hparams

        self.batch_size = hparams.batch_size

        # if you specify an example input, the summary will show input/output for each layer.
        # Sized from in_features so it always matches the first Linear layer
        # (the default in_features is 28 * 28, so the default summary is unchanged).
        self.example_input_array = torch.rand(5, hparams.in_features)

        # build model
        self.__build_model()

    # ---------------------
    # MODEL SETUP
    # ---------------------
    def __build_model(self):
        """
        Create the layers used by ``forward``.

        Layout: in_features -> hidden_dim -> out_features, with batch norm and
        dropout applied after the hidden layer's activation.
        """
        self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
                              out_features=self.hparams.hidden_dim)
        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)

        self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
                              out_features=self.hparams.out_features)

    # ---------------------
    # TRAINING
    # ---------------------
    def forward(self, x):
        """
        No special modification required for lightning, define as you normally would.

        :param x: flattened inputs whose last dimension is ``hparams.in_features``
            (the steps below pass ``(batch, -1)``-shaped views)
        :return: log-probabilities over ``hparams.out_features`` classes (dim=1)
        """
        x = self.c_d1(x)
        x = torch.tanh(x)
        x = self.c_d1_bn(x)
        x = self.c_d1_drop(x)

        x = self.c_d2(x)
        # log_softmax (not raw logits) so the output pairs with F.nll_loss in `loss`
        log_probs = F.log_softmax(x, dim=1)

        return log_probs

    def loss(self, labels, logits):
        """
        Negative log-likelihood loss.

        :param labels: integer class targets
        :param logits: log-probabilities as returned by ``forward``
        :return: mean NLL loss (F.nll_loss default reduction)
        """
        nll = F.nll_loss(logits, labels)
        return nll

    def training_step(self, data_batch, batch_i):
        """
        Lightning calls this inside the training loop.

        :param data_batch: ``(x, y)`` pair of inputs and integer labels
        :param batch_i: index of the batch within the epoch (unused here)
        :return: OrderedDict with the training ``'loss'``
        """
        # forward pass
        x, y = data_batch
        x = x.view(x.size(0), -1)
        y_hat = self.forward(x)

        # calculate loss
        loss_val = self.loss(y, y_hat)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp:
            loss_val = loss_val.unsqueeze(0)

        output = OrderedDict({
            'loss': loss_val
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_step(self, data_batch, batch_i):
        """
        Lightning calls this inside the validation loop.

        :param data_batch: ``(x, y)`` pair of inputs and integer labels
        :param batch_i: index of the batch (unused here)
        :return: OrderedDict with ``'val_loss'`` and ``'val_acc'`` tensors
        """
        x, y = data_batch
        x = x.view(x.size(0), -1)
        y_hat = self.forward(x)

        loss_val = self.loss(y, y_hat)

        # acc: computed directly on the batch's device, so there is no host
        # sync via .item() and no manual .cuda() move back onto the GPU
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = (labels_hat == y).float().mean()

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp:
            loss_val = loss_val.unsqueeze(0)
            val_acc = val_acc.unsqueeze(0)

        output = OrderedDict({
            'val_loss': loss_val,
            'val_acc': val_acc,
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.

        :param outputs: list of individual outputs of each validation step
        :return: dict with ``'val_loss'`` and ``'val_acc'`` averaged over all steps
        """
        # if returned a scalar from validation_step, outputs is a list of tensor scalars
        # we return just the average in this case (if we want)
        # return torch.stack(outputs).mean()

        val_loss_mean = 0
        val_acc_mean = 0
        for output in outputs:
            val_loss = output['val_loss']
            val_acc = output['val_acc']

            # reduce manually when using dp: validation_step added a leading
            # dim to each value, so average across it to get back a scalar
            if self.trainer.use_dp:
                val_loss = torch.mean(val_loss)
                val_acc = torch.mean(val_acc)

            val_loss_mean += val_loss
            val_acc_mean += val_acc

        val_loss_mean /= len(outputs)
        val_acc_mean /= len(outputs)
        tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
        return tqdm_dic

    # ---------------------
    # TRAINING SETUP
    # ---------------------
    def configure_optimizers(self):
        """
        Return whatever optimizers we want here.

        :return: ``([optimizer], [scheduler])`` - Adam with ``hparams.learning_rate``
            and a cosine-annealing LR schedule with ``T_max=10``
        """
        optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
        return [optimizer], [scheduler]

    def __dataloader(self, train):
        """
        Build an MNIST DataLoader.

        :param train: load the MNIST training split if True, else the test split
        :return: a DataLoader; under DDP it uses a DistributedSampler and a
            per-process batch size of ``hparams.batch_size // world_size``
        """
        # init data generators
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5,), (1.0,))])
        dataset = MNIST(root=self.hparams.data_root, train=train,
                        transform=transform, download=True)

        # when using multi-node (ddp) we need to add the datasampler
        train_sampler = None
        batch_size = self.hparams.batch_size

        if self.use_ddp:
            train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
            # scale batch size so hparams.batch_size is the global batch size;
            # clamp to 1 because DataLoader rejects a batch size of 0
            batch_size = max(1, batch_size // self.trainer.world_size)

        # DataLoader does not allow shuffle=True together with a sampler;
        # the DistributedSampler does its own shuffling instead
        should_shuffle = train_sampler is None
        loader = DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=should_shuffle,
            sampler=train_sampler
        )

        return loader

    @pl.data_loader
    def tng_dataloader(self):
        """Training loader over the MNIST training split."""
        print('tng data loader called')
        return self.__dataloader(train=True)

    @pl.data_loader
    def val_dataloader(self):
        """Validation loader; uses the MNIST test split (train=False)."""
        print('val data loader called')
        return self.__dataloader(train=False)

    @pl.data_loader
    def test_dataloader(self):
        """Test loader over the MNIST test split (same data as validation)."""
        print('test data loader called')
        return self.__dataloader(train=False)

    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
        """
        Parameters you define here will be available to your model through self.hparams.

        :param parent_parser: HyperOptArgumentParser whose strategy and arguments are inherited
        :param root_dir: base directory; MNIST is stored under ``<root_dir>/mnist`` by default
        :return: a new HyperOptArgumentParser extended with the model's arguments
        """
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])

        # param overwrites
        # parser.set_defaults(gradient_clip=5.0)

        # network params
        parser.add_argument('--in_features', default=28 * 28, type=int)
        parser.add_argument('--out_features', default=10, type=int)
        # use 500 for CPU, 50000 for GPU to see speed difference
        parser.add_argument('--hidden_dim', default=50000, type=int)
        parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=True)
        parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
                        options=[0.0001, 0.0005, 0.001],
                        tunable=True)

        # data
        parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)

        # training params (opt)
        parser.opt_list('--optimizer_name', default='adam', type=str,
                        options=['adam'], tunable=False)

        # batch_size is the global batch size: under DDP each process gets
        # batch_size // world_size, e.g. 2 nodes with 4 gpus each (world size 8)
        # turn the default 2048 (256 * 8) into 2048 // 8 = 256 per gpu
        parser.opt_list('--batch_size', default=256 * 8, type=int,
                        options=[32, 64, 128, 256], tunable=False,
                        help='batch size will be divided over all gpus being used across all nodes')
        return parser
|