78 lines
2.2 KiB
Python
78 lines
2.2 KiB
Python
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from torch.utils.data import Dataset, DataLoader
|
||
|
|
||
|
from pytorch_lightning import LightningModule
|
||
|
from tests.base.datasets import MNIST
|
||
|
|
||
|
|
||
|
class AverageDataset(Dataset):
    """Synthetic sequence-to-sequence regression dataset of random tensors.

    All data is generated once, at construction time:

    * ``input_seq`` has shape ``(dataset_len, sequence_len, input_size)`` and
      is sampled with ``torch.randn``.
    * ``output_seq`` has shape ``(dataset_len, sequence_len, input_size // 2)``.
      It is the first half of the feature axis plus the second half rolled
      (circularly shifted) by one position along that same axis.

    Note: despite the class name, the target is the *sum* of the two halves,
    not their average.

    Args:
        dataset_len: number of sequences in the dataset.
        sequence_len: number of steps in every sequence.
        input_size: number of features per step. Must be a positive even
            number so the feature axis splits into two equal halves.
            Defaults to 10, the value that used to be hard-coded.

    Raises:
        ValueError: if ``input_size`` is not a positive even number.
    """

    def __init__(self, dataset_len=300, sequence_len=100, input_size=10):
        # An odd width makes ``chunk`` return unequal halves, which would only
        # surface later as a shape mismatch in the addition below; a width
        # below 2 cannot be split into two chunks at all.
        if input_size < 2 or input_size % 2 != 0:
            raise ValueError(
                f"input_size must be a positive even number, got {input_size}"
            )

        self.dataset_len = dataset_len
        self.sequence_len = sequence_len
        self.input_size = input_size

        self.input_seq = torch.randn(dataset_len, sequence_len, input_size)
        # Split the feature (last) axis into two equal halves.
        top, bottom = self.input_seq.chunk(2, -1)
        self.output_seq = top + bottom.roll(shifts=1, dims=-1)

    def __len__(self):
        """Return the number of sequences."""
        return self.dataset_len

    def __getitem__(self, item):
        """Return the ``(input, target)`` pair stored at index ``item``."""
        return self.input_seq[item], self.output_seq[item]
|
||
|
|
||
|
|
||
|
class ParityModuleRNN(LightningModule):
    """LSTM regressor trained on ``AverageDataset`` with an MSE loss.

    A batch-first LSTM (10 input features, 20 hidden units) is followed by a
    linear layer that maps every step's hidden state to 5 outputs.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes the parameter ordering and the
        # order in which random initial weights are drawn.
        self.rnn = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
        self.linear_out = nn.Linear(20, 5)

    def forward(self, x):
        """Run ``x`` through the LSTM and project every step's output."""
        # The LSTM also returns its final (hidden, cell) state; it is unused.
        per_step_output, _ = self.rnn(x)
        return self.linear_out(per_step_output)

    def training_step(self, batch, batch_nb):
        """Return ``{'loss': ...}`` holding the MSE between prediction and target."""
        inputs, targets = batch
        predictions = self(inputs)
        return {'loss': F.mse_loss(predictions, targets)}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=0.02)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer

    def train_dataloader(self):
        """Return a loader over a fresh default ``AverageDataset``, 30 per batch."""
        dataset = AverageDataset()
        return DataLoader(dataset, batch_size=30)
|
||
|
|
||
|
|
||
|
class ParityModuleMNIST(LightningModule):
    """Small fully connected classifier trained on MNIST with cross-entropy.

    Layer stack: Linear(784 -> 128), tanh, BatchNorm1d(128), Dropout(0.3),
    Linear(128 -> 10).
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: it fixes the parameter ordering and the
        # order in which random initial weights are drawn.
        self.c_d1 = nn.Linear(28 * 28, 128)
        self.c_d1_bn = nn.BatchNorm1d(num_features=128)
        self.c_d1_drop = nn.Dropout(p=0.3)
        self.c_d2 = nn.Linear(128, 10)

    def forward(self, x):
        """Flatten each sample of ``x`` and return the 10 output scores."""
        # Collapse everything after the batch axis into one feature vector.
        flat = x.view(x.size(0), -1)
        hidden = torch.tanh(self.c_d1(flat))
        hidden = self.c_d1_drop(self.c_d1_bn(hidden))
        return self.c_d2(hidden)

    def training_step(self, batch, batch_nb):
        """Return ``{'loss': ...}`` holding the cross-entropy for the batch."""
        images, labels = batch
        logits = self(images)
        return {'loss': F.cross_entropy(logits, labels)}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (lr=0.02)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.02)
        return optimizer

    def train_dataloader(self):
        """Return a loader over the MNIST training split, 128 per batch."""
        train_set = MNIST(train=True, download=True)
        return DataLoader(train_set, batch_size=128)
|