lightning/benchmarks/parity_modules.py

78 lines
2.2 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule
from tests.base.datasets import MNIST
class AverageDataset(Dataset):
    """Synthetic sequence-regression dataset held entirely in memory.

    Inputs are drawn once, at construction time, from ``torch.randn`` with
    shape ``(dataset_len, sequence_len, 10)``. The 10 feature columns are
    split into two halves of 5. The target is the first half plus the
    second half cyclically shifted by one position along the feature axis.
    The target therefore has shape ``(dataset_len, sequence_len, 5)``.

    Args:
        dataset_len: number of samples (also the value of ``len(self)``).
        sequence_len: number of timesteps in every sample.
    """

    def __init__(self, dataset_len=300, sequence_len=100):
        self.dataset_len = dataset_len
        self.sequence_len = sequence_len
        # A single draw from the global RNG, so seeding before construction
        # makes the dataset reproducible.
        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
        first_half, second_half = torch.chunk(self.input_seq, 2, dim=-1)
        rotated_half = torch.roll(second_half, shifts=1, dims=-1)
        self.output_seq = first_half + rotated_half

    def __len__(self):
        """Return the number of samples, as given at construction."""
        return self.dataset_len

    def __getitem__(self, item):
        """Return the ``(input, target)`` pair stored at index ``item``."""
        features = self.input_seq[item]
        target = self.output_seq[item]
        return features, target
class ParityModuleRNN(LightningModule):
    """Minimal LSTM regressor used for benchmarking, trained on ``AverageDataset``.

    Architecture: ``nn.LSTM(10, 20, batch_first=True)``, followed by a
    per-timestep ``nn.Linear(20, 5)``. It is trained with an MSE loss and
    Adam (lr=0.02).
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept deliberately: it determines parameter order
        # and the order in which default initialisation consumes the RNG.
        self.rnn = nn.LSTM(10, 20, batch_first=True)
        self.linear_out = nn.Linear(in_features=20, out_features=5)

    def forward(self, x):
        """Project every LSTM output timestep down to 5 features.

        ``batch_first=True`` and ``input_size=10`` imply that ``x`` is
        ``(batch, seq, 10)``. The LSTM's returned hidden/cell state is unused.
        """
        outputs, _state = self.rnn(x)
        return self.linear_out(outputs)

    def training_step(self, batch, batch_nb):
        """Compute the MSE loss between predictions and targets for one batch."""
        inputs, targets = batch
        predictions = self(inputs)
        return {'loss': F.mse_loss(predictions, targets)}

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=0.02)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer

    def train_dataloader(self):
        """Return a DataLoader over a default ``AverageDataset`` in batches of 30."""
        dataset = AverageDataset()
        return DataLoader(dataset, batch_size=30)
class ParityModuleMNIST(LightningModule):
    """Small MLP classifier used for benchmarking on MNIST.

    Pipeline: flatten, ``Linear(784, 128)``, ``tanh``, ``BatchNorm1d(128)``,
    ``Dropout(0.3)``, then ``Linear(128, 10)``, which produces the logits.
    It is trained with cross-entropy and Adam (lr=0.02).
    """

    def __init__(self):
        super().__init__()
        # Keep this exact creation order: it fixes parameter/state_dict order
        # and the RNG consumption of default initialisation.
        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.c_d1_bn = nn.BatchNorm1d(128)
        self.c_d1_drop = nn.Dropout(0.3)
        self.c_d2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return 10 class logits per sample.

        Each sample is flattened with ``view``, so after flattening it must
        contain 28 * 28 values. NOTE(review): ``view`` also assumes ``x`` is
        laid out contiguously — confirm for the dataset in use.
        """
        flat = x.view(x.size(0), -1)
        hidden = self.c_d1_drop(self.c_d1_bn(torch.tanh(self.c_d1(flat))))
        return self.c_d2(hidden)

    def training_step(self, batch, batch_nb):
        """Compute the cross-entropy loss of the logits against the labels."""
        images, labels = batch
        logits = self(images)
        return {'loss': F.cross_entropy(logits, labels)}

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=0.02)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer

    def train_dataloader(self):
        """Return a DataLoader (batch_size=128) over the project's MNIST training split."""
        train_set = MNIST(train=True, download=True)
        return DataLoader(train_set, batch_size=128)