""" To run this template just do: python gan.py After a few epochs, launch tensorboard to see the images being generated at every batch. tensorboard --logdir default """ from argparse import ArgumentParser import os import numpy as np from collections import OrderedDict import torchvision import torchvision.transforms as transforms from torchvision.datasets import MNIST from torch.utils.data import DataLoader import torch.nn as nn import torch.nn.functional as F import torch import pytorch_lightning as pl class Generator(nn.Module): def __init__(self, latent_dim, img_shape): super(Generator, self).__init__() self.img_shape = img_shape def block(in_feat, out_feat, normalize=True): layers = [nn.Linear(in_feat, out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat, 0.8)) layers.append(nn.LeakyReLU(0.2, inplace=True)) return layers self.model = nn.Sequential( *block(latent_dim, 128, normalize=False), *block(128, 256), *block(256, 512), *block(512, 1024), nn.Linear(1024, int(np.prod(img_shape))), nn.Tanh() ) def forward(self, z): img = self.model(z) img = img.view(img.size(0), *self.img_shape) return img class Discriminator(nn.Module): def __init__(self, img_shape): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Linear(int(np.prod(img_shape)), 512), nn.LeakyReLU(0.2, inplace=True), nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), nn.Sigmoid(), ) def forward(self, img): img_flat = img.view(img.size(0), -1) validity = self.model(img_flat) return validity class GAN(pl.LightningModule): def __init__(self, hparams): super(GAN, self).__init__() self.hparams = hparams # networks mnist_shape = (1, 28, 28) self.generator = Generator(latent_dim=hparams.latent_dim, img_shape=mnist_shape) self.discriminator = Discriminator(img_shape=mnist_shape) # cache for generated images self.generated_imgs = None self.last_imgs = None def forward(self, z): return self.generator(z) def adversarial_loss(self, y_hat, y): return F.binary_cross_entropy(y_hat, y) def training_step(self, batch, batch_nb, optimizer_i): imgs, _ = batch self.last_imgs = imgs # train generator if optimizer_i == 0: # sample noise z = torch.randn(imgs.shape[0], self.hparams.latent_dim) # match gpu device (or keep as cpu) if self.on_gpu: z = z.cuda(imgs.device.index) # generate images self.generated_imgs = self.forward(z) # log sampled images # sample_imgs = self.generated_imgs[:6] # grid = torchvision.utils.make_grid(sample_imgs) # self.logger.experiment.add_image('generated_images', grid, 0) # ground truth result (ie: all fake) valid = torch.ones(imgs.size(0), 1) # adversarial loss is binary cross-entropy g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) tqdm_dict = {'g_loss': g_loss} output = OrderedDict({ 'loss': g_loss, 'progress_bar': tqdm_dict, 'log': tqdm_dict }) return output # train discriminator if optimizer_i == 1: # Measure discriminator's ability to classify real from generated samples # how well can it label as real? valid = torch.ones(imgs.size(0), 1) real_loss = self.adversarial_loss(self.discriminator(imgs), valid) # how well can it label as fake? fake = torch.zeros(imgs.size(0), 1) fake_loss = self.adversarial_loss( self.discriminator(self.generated_imgs.detach()), fake) # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 tqdm_dict = {'d_loss': d_loss} output = OrderedDict({ 'loss': d_loss, 'progress_bar': tqdm_dict, 'log': tqdm_dict }) return output def configure_optimizers(self): lr = self.hparams.lr b1 = self.hparams.b1 b2 = self.hparams.b2 opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) return [opt_g, opt_d], [] @pl.data_loader def train_dataloader(self): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform) return DataLoader(dataset, batch_size=self.hparams.batch_size) def on_epoch_end(self): z = torch.randn(8, self.hparams.latent_dim) # match gpu device (or keep as cpu) if self.on_gpu: z = z.cuda(self.last_imgs.device.index) # log sampled images sample_imgs = self.forward(z) grid = torchvision.utils.make_grid(sample_imgs) self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch) def main(hparams): # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ model = GAN(hparams) # ------------------------ # 2 INIT TRAINER # ------------------------ trainer = pl.Trainer() # ------------------------ # 3 START TRAINING # ------------------------ trainer.fit(model) if __name__ == '__main__': parser = ArgumentParser() parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient") parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space") hparams = parser.parse_args() main(hparams)