diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py
index 29ea1b893b..8ebaad7965 100644
--- a/pl_examples/domain_templates/generative_adversarial_net.py
+++ b/pl_examples/domain_templates/generative_adversarial_net.py
@@ -8,18 +8,17 @@ tensorboard --logdir default
 """
 import os
 from argparse import ArgumentParser, Namespace
-from collections import OrderedDict
 
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as F  # noqa
 import torchvision
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
 from torchvision.datasets import MNIST
 
-from pytorch_lightning.core import LightningModule
+from pytorch_lightning.core import LightningModule, LightningDataModule
 from pytorch_lightning.trainer import Trainer
 
 
@@ -60,7 +59,6 @@ class Discriminator(nn.Module):
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(256, 1),
-            nn.Sigmoid(),
         )
 
     def forward(self, img):
@@ -71,54 +69,49 @@ class Discriminator(nn.Module):
 
 class GAN(LightningModule):
 
+    @staticmethod
+    def add_argparse_args(parent_parser: ArgumentParser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
+        parser.add_argument("--b1", type=float, default=0.5,
+                            help="adam: decay of first order momentum of gradient")
+        parser.add_argument("--b2", type=float, default=0.999,
+                            help="adam: decay of second order momentum of gradient")
+        parser.add_argument("--latent_dim", type=int, default=100,
+                            help="dimensionality of the latent space")
-    def __init__(self,
-                 latent_dim: int = 100,
-                 lr: float = 0.0002,
-                 b1: float = 0.5,
-                 b2: float = 0.999,
-                 batch_size: int = 64, **kwargs):
+
+        return parser
+
+    def __init__(self, hparams: Namespace):
         super().__init__()
-        self.latent_dim = latent_dim
-        self.lr = lr
-        self.b1 = b1
-        self.b2 = b2
-        self.batch_size = batch_size
+        self.hparams = hparams
 
         # networks
         mnist_shape = (1, 28, 28)
-        self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
+        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape)
         self.discriminator = Discriminator(img_shape=mnist_shape)
 
-        self.validation_z = torch.randn(8, self.latent_dim)
+        self.validation_z = torch.randn(8, self.hparams.latent_dim)
 
-        self.example_input_array = torch.zeros(2, hparams.latent_dim)
+        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
 
     def forward(self, z):
         return self.generator(z)
 
-    def adversarial_loss(self, y_hat, y):
-        return F.binary_cross_entropy(y_hat, y)
+    @staticmethod
+    def adversarial_loss(y_hat, y):
+        return F.binary_cross_entropy_with_logits(y_hat, y)
 
     def training_step(self, batch, batch_idx, optimizer_idx):
         imgs, _ = batch
 
         # sample noise
-        z = torch.randn(imgs.shape[0], self.latent_dim)
+        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
         z = z.type_as(imgs)
 
         # train generator
         if optimizer_idx == 0:
-
-            # generate images
-            self.generated_imgs = self(z)
-
-            # log sampled images
-            sample_imgs = self.generated_imgs[:6]
-            grid = torchvision.utils.make_grid(sample_imgs)
-            self.logger.experiment.add_image('generated_images', grid, 0)
-
             # ground truth result (ie: all fake)
             # put on GPU because we created this tensor inside training_loop
             valid = torch.ones(imgs.size(0), 1)
@@ -155,20 +148,14 @@ class GAN(LightningModule):
         return d_loss
 
     def configure_optimizers(self):
-        lr = self.lr
-        b1 = self.b1
-        b2 = self.b2
+        lr = self.hparams.lr
+        b1 = self.hparams.b1
+        b2 = self.hparams.b2
 
         opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
         return [opt_g, opt_d], []
 
-    def train_dataloader(self):
-        transform = transforms.Compose([transforms.ToTensor(),
-                                        transforms.Normalize([0.5], [0.5])])
-        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
-        return DataLoader(dataset, batch_size=self.batch_size)
-
     def on_epoch_end(self):
         z = self.validation_z.type_as(self.generator.model[0].weight)
 
@@ -178,36 +165,63 @@ class GAN(LightningModule):
         self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
 
 
+class MNISTDataModule(LightningDataModule):
+
+    def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
+        super().__init__()
+        self.batch_size = batch_size
+        self.data_path = data_path
+        self.num_workers = num_workers
+
+        self.transform = transforms.Compose([transforms.ToTensor(),
+                                             transforms.Normalize([0.5], [0.5])])
+        self.dims = (1, 28, 28)
+
+    def prepare_data(self):
+        # Use this method for work that writes to disk or should run only on a single process
+        # in distributed settings, such as downloading the dataset for the first time.
+        MNIST(self.data_path, train=True, download=True, transform=transforms.ToTensor())
+
+    def setup(self, stage=None):
+        # There are also data operations you might want to perform on every GPU, such as
+        # applying transforms defined explicitly in your datamodule or assigned in init.
+        self.mnist_train = MNIST(self.data_path, train=True, transform=self.transform)
+
+    def train_dataloader(self):
+        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
+
+
 def main(args: Namespace) -> None:
     # ------------------------
     # 1 INIT LIGHTNING MODEL
     # ------------------------
-    model = GAN(**vars(args))
+    model = GAN(args)
 
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-    # If use distubuted training PyTorch recommends to use DistributedDataParallel.
+    # For distributed training, PyTorch recommends DistributedDataParallel.
     # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
-    trainer = Trainer()
+    dm = MNISTDataModule.from_argparse_args(args)
+    trainer = Trainer.from_argparse_args(args)
 
     # ------------------------
     # 3 START TRAINING
     # ------------------------
-    trainer.fit(model)
+    trainer.fit(model, dm)
 
 
 if __name__ == '__main__':
     parser = ArgumentParser()
-    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
-    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
-    parser.add_argument("--b1", type=float, default=0.5,
-                        help="adam: decay of first order momentum of gradient")
-    parser.add_argument("--b2", type=float, default=0.999,
-                        help="adam: decay of first order momentum of gradient")
-    parser.add_argument("--latent_dim", type=int, default=100,
-                        help="dimensionality of the latent space")
-    hparams = parser.parse_args()
+
+    # Add program level args, if any.
+    # ------------------------
+    # Add LightningDataModule args
+    parser = MNISTDataModule.add_argparse_args(parser)
+    # Add model specific args
+    parser = GAN.add_argparse_args(parser)
+    # Add trainer args
+    parser = Trainer.add_argparse_args(parser)
+    # Parse all arguments
+    args = parser.parse_args()
 
-    main(hparams)
+    main(args)
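
Since the refactor means GAN.__init__ now takes a Namespace and the data pipeline lives in MNISTDataModule, the pieces can also be wired up without the CLI. A minimal sketch (not part of the patch), assuming the classes above are importable; the Namespace fields mirror the argparse defaults and max_epochs=5 is an illustrative value:

    from argparse import Namespace

    from pytorch_lightning import Trainer

    # hparams Namespace replaces the old keyword arguments to GAN.__init__
    hparams = Namespace(lr=0.0002, b1=0.5, b2=0.999, latent_dim=100)
    model = GAN(hparams)

    # dataloading now lives outside the LightningModule
    dm = MNISTDataModule(batch_size=64)

    # the datamodule is passed directly to fit()
    Trainer(max_epochs=5).fit(model, dm)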