From b8666bf354f68d62a49812e635034dae7eee0a80 Mon Sep 17 00:00:00 2001 From: Yiming Lin Date: Mon, 14 Oct 2019 11:56:33 +0100 Subject: [PATCH] fix domain_templates (#365) --- examples/domain_templates/gan.py | 54 +++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 12 deletions(-) diff --git a/examples/domain_templates/gan.py b/examples/domain_templates/gan.py index bb1501621f..3dfb9f8c5d 100644 --- a/examples/domain_templates/gan.py +++ b/examples/domain_templates/gan.py @@ -9,6 +9,7 @@ tensorboard --logdir default from argparse import ArgumentParser import os import numpy as np +from collections import OrderedDict import torchvision import torchvision.transforms as transforms @@ -21,7 +22,6 @@ import torch.nn.functional as F import torch import pytorch_lightning as pl -from test_tube import Experiment class Generator(nn.Module): @@ -107,17 +107,22 @@ class GAN(pl.LightningModule): self.generated_imgs = self.forward(z) # log sampled images - sample_imgs = self.generated_imgs[:6] - grid = torchvision.utils.make_grid(sample_imgs) - self.logger.experiment.add_image('generated_images', grid, 0) + # sample_imgs = self.generated_imgs[:6] + # grid = torchvision.utils.make_grid(sample_imgs) + # self.logger.experiment.add_image('generated_images', grid, 0) # ground truth result (ie: all fake) valid = torch.ones(imgs.size(0), 1) # adversarial loss is binary cross-entropy g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) - - return g_loss + tqdm_dict = {'g_loss': g_loss} + output = OrderedDict({ + 'loss': g_loss, + 'progress_bar': tqdm_dict, + 'log': tqdm_dict + }) + return output # train discriminator if optimizer_i == 1: @@ -133,8 +138,15 @@ class GAN(pl.LightningModule): # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 + tqdm_dict = {'d_loss': d_loss} + output = OrderedDict({ + 'loss': d_loss, + 'progress_bar': tqdm_dict, + 'log': tqdm_dict + }) - return d_loss + + return output def configure_optimizers(self): lr = self.hparams.lr @@ -152,16 +164,34 @@ class GAN(pl.LightningModule): dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform) return DataLoader(dataset, batch_size=self.hparams.batch_size) + def on_epoch_end(self): + z = torch.randn(8, self.hparams.latent_dim) + # match gpu device (or keep as cpu) + if self.on_gpu: + z = z.cuda(imgs.device.index) + + # log sampled images + sample_imgs = self.forward(z) + grid = torchvision.utils.make_grid(sample_imgs) + self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch) + + def main(hparams): - # save tensorboard logs - exp = Experiment(save_dir=os.getcwd()) - # init model + # ------------------------ + # 1 INIT LIGHTNING MODEL + # ------------------------ model = GAN(hparams) - # fit trainer on CPU - trainer = pl.Trainer(experiment=exp, max_nb_epochs=200) + # ------------------------ + # 2 INIT TRAINER + # ------------------------ + trainer = pl.Trainer() + + # ------------------------ + # 3 START TRAINING + # ------------------------ trainer.fit(model)