update examples (#4233)

* Removed image generation inside the training step.

It was overwriting the image grid generated in `on_epoch_end`. I also made `adversarial_loss` a static method.
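
The diff below also drops the trailing `nn.Sigmoid()` from the discriminator, which is why the now-static loss uses `binary_cross_entropy_with_logits`. A minimal sketch of why the two changes belong together (the tensors here are just placeholders):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 1)   # raw discriminator outputs, no Sigmoid applied
targets = torch.ones(4, 1)   # "real" labels

# Folding the sigmoid into the loss gives the same value, but is numerically more stable.
with_logits = F.binary_cross_entropy_with_logits(logits, targets)
plain = F.binary_cross_entropy(torch.sigmoid(logits), targets)
assert torch.allclose(with_logits, plain)
```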

* Incorporated Hyperparameter best practices

Using ArgumentParser and hparams as defined in the Hyperparameters section of
the documentation. This way we can set trainer flags (such as precision and
gpus) from the command line.
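
A rough sketch of the pattern (the flag values are only examples; `add_argparse_args` / `from_argparse_args` are the Lightning helpers used further down):

```python
from argparse import ArgumentParser

from pytorch_lightning.trainer import Trainer

# on the command line this becomes e.g.:
#   python generative_adversarial_net.py --gpus 1 --precision 16 --latent_dim 128
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)       # exposes --gpus, --precision, --max_epochs, ...
args = parser.parse_args(["--max_epochs", "5"])  # the real script parses sys.argv instead
trainer = Trainer.from_argparse_args(args)       # parsed flags are forwarded to the Trainer
```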

* Split the data part into a LightningDataModule
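
The new `MNISTDataModule` appears further down in the diff. A handy side effect of the split is that the data pipeline can be exercised on its own, roughly like this (assuming the class is in scope):

```python
dm = MNISTDataModule(batch_size=128)
dm.prepare_data()    # downloads MNIST once
dm.setup()           # builds the training dataset
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)  # torch.Size([128, 1, 28, 28])
```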

* Update pl_examples/domain_templates/generative_adversarial_net.py

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
Author: Sahin Kureta, 2020-10-21 19:07:18 +03:00 (committed by GitHub)
Parent: 8a20d6af51
Commit: 90bb7b1fbb
1 changed file with 66 additions and 52 deletions

@@ -8,18 +8,17 @@ tensorboard --logdir default
 """
 import os
 from argparse import ArgumentParser, Namespace
-from collections import OrderedDict

 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as F  # noqa
 import torchvision
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
 from torchvision.datasets import MNIST

-from pytorch_lightning.core import LightningModule
+from pytorch_lightning.core import LightningModule, LightningDataModule
 from pytorch_lightning.trainer import Trainer
@@ -60,7 +59,6 @@ class Discriminator(nn.Module):
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(256, 1),
-            nn.Sigmoid(),
         )

     def forward(self, img):
@@ -71,54 +69,49 @@ class Discriminator(nn.Module):
 class GAN(LightningModule):

+    @staticmethod
+    def add_argparse_args(parent_parser: ArgumentParser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
+        parser.add_argument("--b1", type=float, default=0.5,
+                            help="adam: decay of first order momentum of gradient")
+        parser.add_argument("--b2", type=float, default=0.999,
+                            help="adam: decay of first order momentum of gradient")
+        parser.add_argument("--latent_dim", type=int, default=100,
+                            help="dimensionality of the latent space")
+
+        return parser
+
-    def __init__(self,
-                 latent_dim: int = 100,
-                 lr: float = 0.0002,
-                 b1: float = 0.5,
-                 b2: float = 0.999,
-                 batch_size: int = 64, **kwargs):
+    def __init__(self, hparams: Namespace):
         super().__init__()
-        self.latent_dim = latent_dim
-        self.lr = lr
-        self.b1 = b1
-        self.b2 = b2
-        self.batch_size = batch_size
+        self.hparams = hparams

         # networks
         mnist_shape = (1, 28, 28)
-        self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
+        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape)
         self.discriminator = Discriminator(img_shape=mnist_shape)

-        self.validation_z = torch.randn(8, self.latent_dim)
+        self.validation_z = torch.randn(8, self.hparams.latent_dim)

-        self.example_input_array = torch.zeros(2, hparams.latent_dim)
+        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)

     def forward(self, z):
         return self.generator(z)

-    def adversarial_loss(self, y_hat, y):
-        return F.binary_cross_entropy(y_hat, y)
+    @staticmethod
+    def adversarial_loss(y_hat, y):
+        return F.binary_cross_entropy_with_logits(y_hat, y)

     def training_step(self, batch, batch_idx, optimizer_idx):
         imgs, _ = batch

         # sample noise
-        z = torch.randn(imgs.shape[0], self.latent_dim)
+        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
         z = z.type_as(imgs)

         # train generator
         if optimizer_idx == 0:

-            # generate images
-            self.generated_imgs = self(z)
-
-            # log sampled images
-            sample_imgs = self.generated_imgs[:6]
-            grid = torchvision.utils.make_grid(sample_imgs)
-            self.logger.experiment.add_image('generated_images', grid, 0)
-
             # ground truth result (ie: all fake)
             # put on GPU because we created this tensor inside training_loop
             valid = torch.ones(imgs.size(0), 1)
@@ -155,20 +148,14 @@ class GAN(LightningModule):
         return d_loss

     def configure_optimizers(self):
-        lr = self.lr
-        b1 = self.b1
-        b2 = self.b2
+        lr = self.hparams.lr
+        b1 = self.hparams.b1
+        b2 = self.hparams.b2

         opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
         return [opt_g, opt_d], []

-    def train_dataloader(self):
-        transform = transforms.Compose([transforms.ToTensor(),
-                                        transforms.Normalize([0.5], [0.5])])
-        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
-        return DataLoader(dataset, batch_size=self.batch_size)
-
     def on_epoch_end(self):
         z = self.validation_z.type_as(self.generator.model[0].weight)
@@ -178,36 +165,63 @@ class GAN(LightningModule):
         self.logger.experiment.add_image('generated_images', grid, self.current_epoch)


+class MNISTDataModule(LightningDataModule):
+    def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
+        super().__init__()
+        self.batch_size = batch_size
+        self.data_path = data_path
+        self.num_workers = num_workers
+
+        self.transform = transforms.Compose([transforms.ToTensor(),
+                                             transforms.Normalize([0.5], [0.5])])
+        self.dims = (1, 28, 28)
+
+    def prepare_data(self, stage=None):
+        # Use this method to do things that might write to disk or that need to be done only from a single GPU
+        # in distributed settings. Like downloading the dataset for the first time.
+        MNIST(self.data_path, train=True, download=True, transform=transforms.ToTensor())
+
+    def setup(self, stage=None):
+        # There are also data operations you might want to perform on every GPU, such as applying transforms
+        # defined explicitly in your datamodule or assigned in init.
+        self.mnist_train = MNIST(self.data_path, train=True, transform=self.transform)
+
+    def train_dataloader(self):
+        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
+
+
 def main(args: Namespace) -> None:
     # ------------------------
     # 1 INIT LIGHTNING MODEL
     # ------------------------
-    model = GAN(**vars(args))
+    model = GAN(args)

     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
     # If use distubuted training PyTorch recommends to use DistributedDataParallel.
     # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
-    trainer = Trainer()
+    dm = MNISTDataModule.from_argparse_args(args)
+    trainer = Trainer.from_argparse_args(args)

     # ------------------------
     # 3 START TRAINING
     # ------------------------
-    trainer.fit(model)
+    trainer.fit(model, dm)


 if __name__ == '__main__':
     parser = ArgumentParser()
-    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
-    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
-    parser.add_argument("--b1", type=float, default=0.5,
-                        help="adam: decay of first order momentum of gradient")
-    parser.add_argument("--b2", type=float, default=0.999,
-                        help="adam: decay of first order momentum of gradient")
-    parser.add_argument("--latent_dim", type=int, default=100,
-                        help="dimensionality of the latent space")
-    hparams = parser.parse_args()
+
+    # Add program level args, if any.
+    # ------------------------
+    # Add LightningDataLoader args
+    parser = MNISTDataModule.add_argparse_args(parser)
+    # Add model specific args
+    parser = GAN.add_argparse_args(parser)
+    # Add trainer args
+    parser = Trainer.add_argparse_args(parser)
+    # Parse all arguments
+    args = parser.parse_args()

-    main(hparams)
+    main(args)
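
For reference, a quick smoke test of the assembled pieces could look like this; the import path is hypothetical and the flag values are only a small example, but the helpers are the ones used above:

```python
from argparse import ArgumentParser

from pytorch_lightning.trainer import Trainer

# hypothetical import path for the updated example module
from generative_adversarial_net import GAN, MNISTDataModule

parser = ArgumentParser()
parser = MNISTDataModule.add_argparse_args(parser)  # --batch_size, --data_path, --num_workers
parser = GAN.add_argparse_args(parser)              # --lr, --b1, --b2, --latent_dim
parser = Trainer.add_argparse_args(parser)          # --gpus, --precision, --max_epochs, ...
args = parser.parse_args(["--batch_size", "32", "--max_epochs", "1"])

model = GAN(args)
dm = MNISTDataModule.from_argparse_args(args)
trainer = Trainer.from_argparse_args(args)
trainer.fit(model, dm)
```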