update examples (#4233)
* Removed image generation inside the training step. It was overwriting the image grid generated in `on_epoch_end`. I also made `adversarial_loss` a static method.
* Incorporated hyperparameter best practices, using ArgumentParser and hparams as defined in the Hyperparameters section of the documentation. This way we can set trainer flags (such as precision and gpus) from the command line.
* Split the data part into a LightningDataModule.
* Update pl_examples/domain_templates/generative_adversarial_net.py

Co-authored-by: Jeff Yang <ydcjeff@outlook.com>
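Together, the hyperparameter and datamodule changes mean that trainer-level, model-level, and data-level options all flow through one ArgumentParser, so a run such as `python generative_adversarial_net.py --gpus 1 --precision 16 --max_epochs 5 --latent_dim 64` (flag values here are illustrative, not from this commit) requires no code changes.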
parent 8a20d6af51
commit 90bb7b1fbb
@@ -8,18 +8,17 @@ tensorboard --logdir default
 """
 import os
 from argparse import ArgumentParser, Namespace
-from collections import OrderedDict
 
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+import torch.nn.functional as F  # noqa
 import torchvision
 import torchvision.transforms as transforms
 from torch.utils.data import DataLoader
 from torchvision.datasets import MNIST
 
-from pytorch_lightning.core import LightningModule
+from pytorch_lightning.core import LightningModule, LightningDataModule
 from pytorch_lightning.trainer import Trainer
 
 
@@ -60,7 +59,6 @@ class Discriminator(nn.Module):
             nn.Linear(512, 256),
             nn.LeakyReLU(0.2, inplace=True),
             nn.Linear(256, 1),
-            nn.Sigmoid(),
         )
 
     def forward(self, img):
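Dropping the final nn.Sigmoid() goes hand in hand with the loss change in the next hunk: the discriminator now emits raw logits and the loss fuses the sigmoid. A minimal sketch of the equivalence (illustrative only, not part of this diff):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 1)   # raw discriminator outputs
targets = torch.ones(4, 1)   # "real" labels

# the fused form computes the same quantity as sigmoid + BCE,
# but avoids saturating log(sigmoid(x)) for large |x|
fused = F.binary_cross_entropy_with_logits(logits, targets)
two_step = F.binary_cross_entropy(torch.sigmoid(logits), targets)
assert torch.allclose(fused, two_step, atol=1e-5)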
@@ -71,54 +69,49 @@ class Discriminator(nn.Module):
 
 
 class GAN(LightningModule):
-    def __init__(self,
-                 latent_dim: int = 100,
-                 lr: float = 0.0002,
-                 b1: float = 0.5,
-                 b2: float = 0.999,
-                 batch_size: int = 64, **kwargs):
+    @staticmethod
+    def add_argparse_args(parent_parser: ArgumentParser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
+        parser.add_argument("--b1", type=float, default=0.5,
+                            help="adam: decay of first order momentum of gradient")
+        parser.add_argument("--b2", type=float, default=0.999,
+                            help="adam: decay of second order momentum of gradient")
+        parser.add_argument("--latent_dim", type=int, default=100,
+                            help="dimensionality of the latent space")
+        return parser
+
+    def __init__(self, hparams: Namespace):
         super().__init__()
 
-        self.latent_dim = latent_dim
-        self.lr = lr
-        self.b1 = b1
-        self.b2 = b2
-        self.batch_size = batch_size
+        self.hparams = hparams
 
         # networks
         mnist_shape = (1, 28, 28)
-        self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
+        self.generator = Generator(latent_dim=self.hparams.latent_dim, img_shape=mnist_shape)
         self.discriminator = Discriminator(img_shape=mnist_shape)
 
-        self.validation_z = torch.randn(8, self.latent_dim)
+        self.validation_z = torch.randn(8, self.hparams.latent_dim)
 
-        self.example_input_array = torch.zeros(2, hparams.latent_dim)
+        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
 
     def forward(self, z):
         return self.generator(z)
 
-    def adversarial_loss(self, y_hat, y):
-        return F.binary_cross_entropy(y_hat, y)
+    @staticmethod
+    def adversarial_loss(y_hat, y):
+        return F.binary_cross_entropy_with_logits(y_hat, y)
 
     def training_step(self, batch, batch_idx, optimizer_idx):
         imgs, _ = batch
 
         # sample noise
-        z = torch.randn(imgs.shape[0], self.latent_dim)
+        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
         z = z.type_as(imgs)
 
         # train generator
         if optimizer_idx == 0:
 
-            # generate images
-            self.generated_imgs = self(z)
-
-            # log sampled images
-            sample_imgs = self.generated_imgs[:6]
-            grid = torchvision.utils.make_grid(sample_imgs)
-            self.logger.experiment.add_image('generated_images', grid, 0)
-
             # ground truth result (ie: all fake)
             # put on GPU because we created this tensor inside training_loop
             valid = torch.ones(imgs.size(0), 1)
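The new static add_argparse_args hook lets the model register its own flags on a shared parser. A minimal sketch of how it composes (GAN is the class from this file; the flag values are assumptions for illustration):

from argparse import ArgumentParser

parser = ArgumentParser()
parser = GAN.add_argparse_args(parser)   # registers --lr, --b1, --b2, --latent_dim
args = parser.parse_args(["--lr", "0.001", "--latent_dim", "64"])

model = GAN(args)                        # the parsed Namespace becomes model.hparams
assert model.hparams.latent_dim == 64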
@@ -155,20 +148,14 @@ class GAN(LightningModule):
         return d_loss
 
     def configure_optimizers(self):
-        lr = self.lr
-        b1 = self.b1
-        b2 = self.b2
+        lr = self.hparams.lr
+        b1 = self.hparams.b1
+        b2 = self.hparams.b2
 
         opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
         return [opt_g, opt_d], []
 
-    def train_dataloader(self):
-        transform = transforms.Compose([transforms.ToTensor(),
-                                        transforms.Normalize([0.5], [0.5])])
-        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
-        return DataLoader(dataset, batch_size=self.batch_size)
-
     def on_epoch_end(self):
         z = self.validation_z.type_as(self.generator.model[0].weight)
 
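For readers new to the two-list return value of configure_optimizers: the first list holds optimizers, the second (empty here) would hold LR schedulers. Roughly, Lightning's multi-optimizer loop of this era behaves like the simplified sketch below (dataloader, model, opt_g, and opt_d are assumed in scope; the real loop adds device moves, hooks, and closures):

for batch_idx, batch in enumerate(dataloader):
    for optimizer_idx, optimizer in enumerate([opt_g, opt_d]):
        # 0 -> generator step, 1 -> discriminator step in this example
        loss = model.training_step(batch, batch_idx, optimizer_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()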
@@ -178,36 +165,63 @@ class GAN(LightningModule):
         self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
 
 
+class MNISTDataModule(LightningDataModule):
+
+    def __init__(self, batch_size: int = 64, data_path: str = os.getcwd(), num_workers: int = 4):
+        super().__init__()
+        self.batch_size = batch_size
+        self.data_path = data_path
+        self.num_workers = num_workers
+
+        self.transform = transforms.Compose([transforms.ToTensor(),
+                                             transforms.Normalize([0.5], [0.5])])
+        self.dims = (1, 28, 28)
+
+    def prepare_data(self, stage=None):
+        # Use this method to do things that might write to disk or that need to be done only from a single GPU
+        # in distributed settings. Like downloading the dataset for the first time.
+        MNIST(self.data_path, train=True, download=True, transform=transforms.ToTensor())
+
+    def setup(self, stage=None):
+        # There are also data operations you might want to perform on every GPU, such as applying transforms
+        # defined explicitly in your datamodule or assigned in init.
+        self.mnist_train = MNIST(self.data_path, train=True, transform=self.transform)
+
+    def train_dataloader(self):
+        return DataLoader(self.mnist_train, batch_size=self.batch_size, num_workers=self.num_workers)
+
+
 def main(args: Namespace) -> None:
     # ------------------------
     # 1 INIT LIGHTNING MODEL
     # ------------------------
-    model = GAN(**vars(args))
+    model = GAN(args)
 
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
     # If using distributed training, PyTorch recommends DistributedDataParallel.
     # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
-    trainer = Trainer()
+    dm = MNISTDataModule.from_argparse_args(args)
+    trainer = Trainer.from_argparse_args(args)
 
     # ------------------------
     # 3 START TRAINING
     # ------------------------
-    trainer.fit(model)
+    trainer.fit(model, dm)
 
 
 if __name__ == '__main__':
     parser = ArgumentParser()
-    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
-    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
-    parser.add_argument("--b1", type=float, default=0.5,
-                        help="adam: decay of first order momentum of gradient")
-    parser.add_argument("--b2", type=float, default=0.999,
-                        help="adam: decay of first order momentum of gradient")
-    parser.add_argument("--latent_dim", type=int, default=100,
-                        help="dimensionality of the latent space")
 
-    hparams = parser.parse_args()
+    # Add program level args, if any.
+    # ------------------------
+    # Add LightningDataModule args
+    parser = MNISTDataModule.add_argparse_args(parser)
+    # Add model specific args
+    parser = GAN.add_argparse_args(parser)
+    # Add trainer args
+    parser = Trainer.add_argparse_args(parser)
+    # Parse all arguments
+    args = parser.parse_args()
 
-    main(hparams)
+    main(args)
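The new datamodule also works standalone, which is handy for quick checks outside the Trainer. A minimal sketch of its lifecycle (illustrative, not part of this commit):

dm = MNISTDataModule(batch_size=32)
dm.prepare_data()   # downloads MNIST once (disk-writing work)
dm.setup()          # builds the transformed training set
imgs, labels = next(iter(dm.train_dataloader()))
print(imgs.shape)   # torch.Size([32, 1, 28, 28]), matching dm.dims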