fix domain_templates (#365)
This commit is contained in:
parent
19c2b8fc9e
commit
b8666bf354
|
@ -9,6 +9,7 @@ tensorboard --logdir default
|
|||
from argparse import ArgumentParser
|
||||
import os
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
|
@ -21,7 +22,6 @@ import torch.nn.functional as F
|
|||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from test_tube import Experiment
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
|
@ -107,17 +107,22 @@ class GAN(pl.LightningModule):
|
|||
self.generated_imgs = self.forward(z)
|
||||
|
||||
# log sampled images
|
||||
sample_imgs = self.generated_imgs[:6]
|
||||
grid = torchvision.utils.make_grid(sample_imgs)
|
||||
self.logger.experiment.add_image('generated_images', grid, 0)
|
||||
# sample_imgs = self.generated_imgs[:6]
|
||||
# grid = torchvision.utils.make_grid(sample_imgs)
|
||||
# self.logger.experiment.add_image('generated_images', grid, 0)
|
||||
|
||||
# ground truth result (ie: all fake)
|
||||
valid = torch.ones(imgs.size(0), 1)
|
||||
|
||||
# adversarial loss is binary cross-entropy
|
||||
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
|
||||
|
||||
return g_loss
|
||||
tqdm_dict = {'g_loss': g_loss}
|
||||
output = OrderedDict({
|
||||
'loss': g_loss,
|
||||
'progress_bar': tqdm_dict,
|
||||
'log': tqdm_dict
|
||||
})
|
||||
return output
|
||||
|
||||
# train discriminator
|
||||
if optimizer_i == 1:
|
||||
|
@ -133,8 +138,15 @@ class GAN(pl.LightningModule):
|
|||
|
||||
# discriminator loss is the average of these
|
||||
d_loss = (real_loss + fake_loss) / 2
|
||||
tqdm_dict = {'d_loss': d_loss}
|
||||
output = OrderedDict({
|
||||
'loss': d_loss,
|
||||
'progress_bar': tqdm_dict,
|
||||
'log': tqdm_dict
|
||||
})
|
||||
|
||||
return d_loss
|
||||
|
||||
return output
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.hparams.lr
|
||||
|
@ -152,16 +164,34 @@ class GAN(pl.LightningModule):
|
|||
dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
|
||||
return DataLoader(dataset, batch_size=self.hparams.batch_size)
|
||||
|
||||
def on_epoch_end(self):
|
||||
z = torch.randn(8, self.hparams.latent_dim)
|
||||
# match gpu device (or keep as cpu)
|
||||
if self.on_gpu:
|
||||
z = z.cuda(imgs.device.index)
|
||||
|
||||
# log sampled images
|
||||
sample_imgs = self.forward(z)
|
||||
grid = torchvision.utils.make_grid(sample_imgs)
|
||||
self.logger.experiment.add_image(f'generated_images', grid, self.current_epoch)
|
||||
|
||||
|
||||
|
||||
def main(hparams):
|
||||
# save tensorboard logs
|
||||
exp = Experiment(save_dir=os.getcwd())
|
||||
|
||||
# init model
|
||||
# ------------------------
|
||||
# 1 INIT LIGHTNING MODEL
|
||||
# ------------------------
|
||||
model = GAN(hparams)
|
||||
|
||||
# fit trainer on CPU
|
||||
trainer = pl.Trainer(experiment=exp, max_nb_epochs=200)
|
||||
# ------------------------
|
||||
# 2 INIT TRAINER
|
||||
# ------------------------
|
||||
trainer = pl.Trainer()
|
||||
|
||||
# ------------------------
|
||||
# 3 START TRAINING
|
||||
# ------------------------
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue