From 90bb7b1fbb913305e786dc067308cee5c9fe7fbc Mon Sep 17 00:00:00 2001 From: Sahin Kureta Date: Wed, 21 Oct 2020 19:07:18 +0300 Subject: [PATCH] update examples (#4233) * Removed image generation inside the training step. It was overwriting the image grid generated in `on_epoch_end`. I also made `adversarial_loss` a static method. * Incorporated Hyperparameter best practices Using ArgumentParser and hparams as defined in the Hyperparameters section of the documentation. This way we can set trainer flags (such as precision, and gpus) from the command line. * Incorporated Hyperparameter best practices Using ArgumentParser and hparams as defined in the Hyperparameters section of the documentation. This way we can set trainer flags (such as precision, and gpus) from the command line. * Split the data part into a LightningDataModule * Update pl_examples/domain_templates/generative_adversarial_net.py Co-authored-by: Jeff Yang --- .../generative_adversarial_net.py | 118 ++++++++++-------- 1 file changed, 66 insertions(+), 52 deletions(-) diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 29ea1b893b..8ebaad7965 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -8,18 +8,17 @@ tensorboard --logdir default """ import os from argparse import ArgumentParser, Namespace -from collections import OrderedDict import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F +import torch.nn.functional as F # noqa import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import MNIST -from pytorch_lightning.core import LightningModule +from pytorch_lightning.core import LightningModule, LightningDataModule from pytorch_lightning.trainer import Trainer @@ -60,7 +59,6 @@ class Discriminator(nn.Module): nn.Linear(512, 
256), nn.LeakyReLU(0.2, inplace=True), nn.Linear(256, 1), - nn.Sigmoid(), ) def forward(self, img): @@ -71,54 +69,49 @@ class Discriminator(nn.Module): class GAN(LightningModule): + @staticmethod + def add_argparse_args(parent_parser: ArgumentParser): + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") + parser.add_argument("--b1", type=float, default=0.5, + help="adam: decay of first order momentum of gradient") + parser.add_argument("--b2", type=float, default=0.999, + help="adam: decay of first order momentum of gradient") + parser.add_argument("--latent_dim", type=int, default=100, + help="dimensionality of the latent space") - def __init__(self, - latent_dim: int = 100, - lr: float = 0.0002, - b1: float = 0.5, - b2: float = 0.999, - batch_size: int = 64, **kwargs): + return parser + + def __init__(self, hparams: Namespace): super().__init__() - self.latent_dim = latent_dim - self.lr = lr - self.b1 = b1 - self.b2 = b2 - self.batch_size = batch_size + self.hparams = hparams # networks mnist_shape = (1, 28, 28) - self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape) + self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape) self.discriminator = Discriminator(img_shape=mnist_shape) - self.validation_z = torch.randn(8, self.latent_dim) + self.validation_z = torch.randn(8, self.hparams.latent_dim) - self.example_input_array = torch.zeros(2, hparams.latent_dim) + self.example_input_array = torch.zeros(2, self.hparams.latent_dim) def forward(self, z): return self.generator(z) - def adversarial_loss(self, y_hat, y): - return F.binary_cross_entropy(y_hat, y) + @staticmethod + def adversarial_loss(y_hat, y): + return F.binary_cross_entropy_with_logits(y_hat, y) def training_step(self, batch, batch_idx, optimizer_idx): imgs, _ = batch # sample noise - z = torch.randn(imgs.shape[0], self.latent_dim) + z = 
torch.randn(imgs.shape[0], self.hparams.latent_dim) z = z.type_as(imgs) # train generator if optimizer_idx == 0: - - # generate images - self.generated_imgs = self(z) - - # log sampled images - sample_imgs = self.generated_imgs[:6] - grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image('generated_images', grid, 0) - # ground truth result (ie: all fake) # put on GPU because we created this tensor inside training_loop valid = torch.ones(imgs.size(0), 1) @@ -155,20 +148,14 @@ class GAN(LightningModule): return d_loss def configure_optimizers(self): - lr = self.lr - b1 = self.b1 - b2 = self.b2 + lr = self.hparams.lr + b1 = self.hparams.b1 + b2 = self.hparams.b2 opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) return [opt_g, opt_d], [] - def train_dataloader(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize([0.5], [0.5])]) - dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform) - return DataLoader(dataset, batch_size=self.batch_size) - def on_epoch_end(self): z = self.validation_z.type_as(self.generator.model[0].weight) @@ -178,36 +165,63 @@ class GAN(LightningModule): self.logger.experiment.add_image('generated_images', grid, self.current_epoch) +class MNISTDataModule(LightningDataModule): + def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4): + super().__init__() + self.batch_size = batch_size + self.data_path = data_path + self.num_workers = num_workers + + self.transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize([0.5], [0.5])]) + self.dims = (1, 28, 28) + + def prepare_data(self, stage=None): + # Use this method to do things that might write to disk or that need to be done only from a single GPU + # in distributed settings. Like downloading the dataset for the first time. 
+ MNIST(self.data_path, train=True, download=True, transform=transforms.ToTensor()) + + def setup(self, stage=None): + # There are also data operations you might want to perform on every GPU, such as applying transforms + # defined explicitly in your datamodule or assigned in init. + self.mnist_train = MNIST(self.data_path, train=True, transform=self.transform) + + def train_dataloader(self): + return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers) + + def main(args: Namespace) -> None: # ------------------------ # 1 INIT LIGHTNING MODEL # ------------------------ - model = GAN(**vars(args)) + model = GAN(args) # ------------------------ # 2 INIT TRAINER # ------------------------ # If use distubuted training PyTorch recommends to use DistributedDataParallel. # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel - trainer = Trainer() + dm = MNISTDataModule.from_argparse_args(args) + trainer = Trainer.from_argparse_args(args) # ------------------------ # 3 START TRAINING # ------------------------ - trainer.fit(model) + trainer.fit(model, dm) if __name__ == '__main__': parser = ArgumentParser() - parser.add_argument("--batch_size", type=int, default=64, help="size of the batches") - parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate") - parser.add_argument("--b1", type=float, default=0.5, - help="adam: decay of first order momentum of gradient") - parser.add_argument("--b2", type=float, default=0.999, - help="adam: decay of first order momentum of gradient") - parser.add_argument("--latent_dim", type=int, default=100, - help="dimensionality of the latent space") - hparams = parser.parse_args() + # Add program level args, if any. 
+ # ------------------------ + # Add LightningDataModule args + parser = MNISTDataModule.add_argparse_args(parser) + # Add model specific args + parser = GAN.add_argparse_args(parser) + # Add trainer args + parser = Trainer.add_argparse_args(parser) + # Parse all arguments + args = parser.parse_args() - main(hparams) + main(args)