Bugfix/fix gan example (#2019)

* 🐛 fixed the fake-label tensor type assignment and the hparams constructor arg

* fixed GAN example to work with dp, ddp, ddp_cpu

* Update generative_adversarial_net.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
Author: Artem Lobantsev, 2020-05-31 15:31:21 +03:00 (committed by GitHub)
Parent: 0e37e8c4d2
Commit: 55fdfe3845
1 changed file with 17 additions and 17 deletions
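The gist of the fix, as a minimal sketch (assuming the GAN LightningModule from this example file and standard argparse wiring; the flags shown are illustrative, not the full example):

import torch
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--latent_dim', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=64)
args = parser.parse_args([])

# the module is now built from explicit keyword arguments instead of a single hparams object
model = GAN(**vars(args))

# device-agnostic noise: type_as puts the tensor on the same device/dtype as the
# generator's weights, which is what lets the example run under dp, ddp and ddp_cpu
z = torch.randn(args.batch_size, args.latent_dim)
z = z.type_as(next(model.generator.parameters()))
fake_imgs = model(z)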

@@ -79,6 +79,7 @@ class GAN(LightningModule):
b2: float = 0.999,
batch_size: int = 64, **kwargs):
super().__init__()
self.latent_dim = latent_dim
self.lr = lr
self.b1 = b1
@@ -90,9 +91,7 @@ class GAN(LightningModule):
self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
self.discriminator = Discriminator(img_shape=mnist_shape)
# cache for generated images
self.generated_imgs = None
self.last_imgs = None
self.validation_z = torch.randn(8, self.latent_dim)
def forward(self, z):
return self.generator(z)
@@ -102,21 +101,21 @@ class GAN(LightningModule):
def training_step(self, batch, batch_idx, optimizer_idx):
imgs, _ = batch
self.last_imgs = imgs
# sample noise
z = torch.randn(imgs.shape[0], self.latent_dim)
z = z.type_as(imgs)
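# note (editor's sketch, not a line from this commit): sampling z once, before the
# optimizer_idx branches, lets the discriminator step below reuse self(z) instead of
# cached module attributes, which are not preserved across dp/ddp replicas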
# train generator
if optimizer_idx == 0:
# sample noise
z = torch.randn(imgs.shape[0], self.latent_dim)
z = z.type_as(imgs)
# generate images
self.generated_imgs = self(z)
# log sampled images
# sample_imgs = self.generated_imgs[:6]
# grid = torchvision.utils.make_grid(sample_imgs)
# self.logger.experiment.add_image('generated_images', grid, 0)
sample_imgs = self.generated_imgs[:6]
grid = torchvision.utils.make_grid(sample_imgs)
self.logger.experiment.add_image('generated_images', grid, 0)
# ground truth result (label the generated images as valid so the generator learns to fool the discriminator)
# put on GPU because we created this tensor inside training_loop
@@ -124,7 +123,7 @@ class GAN(LightningModule):
valid = valid.type_as(imgs)
# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
tqdm_dict = {'g_loss': g_loss}
output = OrderedDict({
'loss': g_loss,
@@ -145,10 +144,10 @@ class GAN(LightningModule):
# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
fake = fake.type_as(fake)
fake = fake.type_as(imgs)
fake_loss = self.adversarial_loss(
self.discriminator(self.generated_imgs.detach()), fake)
self.discriminator(self(z).detach()), fake)
# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
@@ -176,8 +175,7 @@ class GAN(LightningModule):
return DataLoader(dataset, batch_size=self.batch_size)
def on_epoch_end(self):
z = torch.randn(8, self.latent_dim)
z = z.type_as(self.last_imgs)
z = self.validation_z.type_as(self.generator.model[0].weight)
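# note (editor's sketch, not a line from this commit): type_as against a generator weight
# keeps the fixed validation noise on the module's current device/dtype, so on_epoch_end
# still works once the model is moved to GPU or wrapped for dp/ddp/ddp_cpu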
# log sampled images
sample_imgs = self(z)
@@ -185,15 +183,17 @@ class GAN(LightningModule):
self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
def main(hparams):
def main(args):
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
model = GAN(hparams)
model = GAN(**vars(args))
# ------------------------
# 2 INIT TRAINER
# ------------------------
# If using distributed training, PyTorch recommends DistributedDataParallel.
# See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
trainer = Trainer()
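# sketch (not part of this commit): with the Lightning API of this era the backend was
# selected via Trainer flags, roughly Trainer(gpus=2, distributed_backend='ddp');
# treat the exact flag names as assumptions and check the installed version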
# ------------------------