fixed gan template (#528)

* fixed gan template

* Update gan.py
This commit is contained in:
William Falcon 2019-12-04 08:28:46 -05:00 committed by GitHub
parent 218f0a5b4a
commit 6ba30a113d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 0 deletions

View File

@ -112,7 +112,10 @@ class GAN(pl.LightningModule):
# self.logger.experiment.add_image('generated_images', grid, 0)
# ground truth result (ie: all fake)
# put on GPU because we created this tensor inside training_loop
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
valid = valid.cuda(imgs.device.index)
# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
@ -130,10 +133,16 @@ class GAN(pl.LightningModule):
# how well can it label as real?
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
valid = valid.cuda(imgs.device.index)
real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
if self.on_gpu:
fake = fake.cuda(imgs.device.index)
fake_loss = self.adversarial_loss(
self.discriminator(self.generated_imgs.detach()), fake)