parent
218f0a5b4a
commit
6ba30a113d
|
@ -112,7 +112,10 @@ class GAN(pl.LightningModule):
|
|||
# self.logger.experiment.add_image('generated_images', grid, 0)
|
||||
|
||||
# ground truth result (ie: all fake)
|
||||
# put on GPU because we created this tensor inside training_loop
|
||||
valid = torch.ones(imgs.size(0), 1)
|
||||
if self.on_gpu:
|
||||
valid = valid.cuda(imgs.device.index)
|
||||
|
||||
# adversarial loss is binary cross-entropy
|
||||
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
|
||||
|
@ -130,10 +133,16 @@ class GAN(pl.LightningModule):
|
|||
|
||||
# how well can it label as real?
|
||||
valid = torch.ones(imgs.size(0), 1)
|
||||
if self.on_gpu:
|
||||
valid = valid.cuda(imgs.device.index)
|
||||
|
||||
real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
|
||||
|
||||
# how well can it label as fake?
|
||||
fake = torch.zeros(imgs.size(0), 1)
|
||||
if self.on_gpu:
|
||||
fake = fake.cuda(imgs.device.index)
|
||||
|
||||
fake_loss = self.adversarial_loss(
|
||||
self.discriminator(self.generated_imgs.detach()), fake)
|
||||
|
||||
|
|
Loading…
Reference in New Issue