fix domain_templates (#365)

This commit is contained in:
Yiming Lin 2019-10-14 11:56:33 +01:00 committed by William Falcon
parent 19c2b8fc9e
commit b8666bf354
1 changed files with 42 additions and 12 deletions

View File

@ -9,6 +9,7 @@ tensorboard --logdir default
from argparse import ArgumentParser
import os
import numpy as np
from collections import OrderedDict
import torchvision
import torchvision.transforms as transforms
@ -21,7 +22,6 @@ import torch.nn.functional as F
import torch
import pytorch_lightning as pl
from test_tube import Experiment
class Generator(nn.Module):
@ -107,17 +107,22 @@ class GAN(pl.LightningModule):
self.generated_imgs = self.forward(z)
# log sampled images
sample_imgs = self.generated_imgs[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('generated_images', grid, 0)
# sample_imgs = self.generated_imgs[:6]
# grid = torchvision.utils.make_grid(sample_imgs)
# self.logger.experiment.add_image('generated_images', grid, 0)
# ground truth result (ie: all fake)
valid = torch.ones(imgs.size(0), 1)
# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
return g_loss
tqdm_dict = {'g_loss': g_loss}
output = OrderedDict({
'loss': g_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
})
return output
# train discriminator
if optimizer_i == 1:
@ -133,8 +138,15 @@ class GAN(pl.LightningModule):
# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
tqdm_dict = {'d_loss': d_loss}
output = OrderedDict({
'loss': d_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
})
return d_loss
return output
def configure_optimizers(self):
lr = self.hparams.lr
@ -152,16 +164,34 @@ class GAN(pl.LightningModule):
dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
return DataLoader(dataset, batch_size=self.hparams.batch_size)
def on_epoch_end(self):
z = torch.randn(8, self.hparams.latent_dim)
# match gpu device (or keep as cpu)
if self.on_gpu:
z = z.cuda(imgs.device.index)
# log sampled images
sample_imgs = self.forward(z)
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch)
def main(hparams):
# save tensorboard logs
exp = Experiment(save_dir=os.getcwd())
# init model
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
model = GAN(hparams)
# fit trainer on CPU
trainer = pl.Trainer(experiment=exp, max_nb_epochs=200)
# ------------------------
# 2 INIT TRAINER
# ------------------------
trainer = pl.Trainer()
# ------------------------
# 3 START TRAINING
# ------------------------
trainer.fit(model)