import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from pytorch_lightning import LightningModule
from tests.base.datasets import MNIST


class AverageDataset(Dataset):
    """Synthetic sequence-regression dataset, generated once at construction.

    Inputs are standard-normal tensors of shape ``(dataset_len, sequence_len, 10)``.
    Each target splits the 10 input features into two halves of 5. It adds the
    first half to the second half, cyclically shifted by one position along the
    feature axis. Targets therefore have shape ``(dataset_len, sequence_len, 5)``.

    NOTE(review): despite the class name, the target is a sum, not an average.

    Args:
        dataset_len: number of samples.
        sequence_len: number of time steps per sample.
    """

    def __init__(self, dataset_len=300, sequence_len=100):
        self.dataset_len = dataset_len
        self.sequence_len = sequence_len
        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
        # Split the last (feature) dimension into two halves of 5 features each.
        top, bottom = self.input_seq.chunk(2, -1)
        # roll() with dims=-1 shifts features, not time steps.
        self.output_seq = top + bottom.roll(shifts=1, dims=-1)

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, item):
        """Return the ``(input, target)`` pair at index ``item``."""
        return self.input_seq[item], self.output_seq[item]


class ParityModuleRNN(LightningModule):
    """Small LSTM regressor trained on :class:`AverageDataset`.

    An LSTM (10 -> 20 hidden, ``batch_first=True``) is followed by a linear
    layer (20 -> 5) applied to every time step. The model is presumably used
    for parity comparisons against plain PyTorch; confirm against callers.
    """

    def __init__(self):
        super().__init__()
        # Keep construction order: parameter initialisation consumes the global
        # RNG, so reordering would change the initial weights.
        self.rnn = nn.LSTM(10, 20, batch_first=True)
        self.linear_out = nn.Linear(in_features=20, out_features=5)

    def forward(self, x):
        """Map ``(batch, seq, 10)`` inputs to ``(batch, seq, 5)`` predictions."""
        # Only the per-step outputs are used; the final (h, c) state is discarded.
        seq, _ = self.rnn(x)
        return self.linear_out(seq)

    def training_step(self, batch, batch_nb):
        """Return the MSE loss between predictions and targets for one batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        # A fresh random dataset is generated each time this is called.
        return DataLoader(AverageDataset(), batch_size=30)


class ParityModuleMNIST(LightningModule):
    """Two-layer MLP classifier for MNIST.

    Architecture: flatten -> Linear(784, 128) -> tanh -> BatchNorm1d ->
    Dropout(0.3) -> Linear(128, 10). The model returns raw logits, which are
    fed to ``cross_entropy`` in :meth:`training_step`.
    """

    def __init__(self):
        super().__init__()
        # Keep construction order: it determines RNG consumption during
        # initialisation and the order of ``self.parameters()``.
        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.c_d1_bn = nn.BatchNorm1d(128)
        self.c_d1_drop = nn.Dropout(0.3)
        self.c_d2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return 10-way logits for a batch of images.

        Assumes each sample flattens to 28 * 28 = 784 values
        (e.g. ``(batch, 1, 28, 28)``); TODO confirm against the dataset.
        """
        x = x.view(x.size(0), -1)
        x = self.c_d1(x)
        x = torch.tanh(x)
        x = self.c_d1_bn(x)
        x = self.c_d1_drop(x)
        x = self.c_d2(x)
        return x

    def training_step(self, batch, batch_nb):
        """Return the cross-entropy loss of the logits against the labels."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        return DataLoader(MNIST(train=True, download=True), batch_size=128)