diff --git a/pl_examples/domain_templates/gan.py b/pl_examples/domain_templates/gan.py index f86628a194..78a813e82f 100644 --- a/pl_examples/domain_templates/gan.py +++ b/pl_examples/domain_templates/gan.py @@ -112,7 +112,10 @@ class GAN(pl.LightningModule): # self.logger.experiment.add_image('generated_images', grid, 0) # ground truth result (ie: all fake) + # put on GPU because we created this tensor inside training_loop valid = torch.ones(imgs.size(0), 1) + if self.on_gpu: + valid = valid.cuda(imgs.device.index) # adversarial loss is binary cross-entropy g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) @@ -130,10 +133,16 @@ class GAN(pl.LightningModule): # how well can it label as real? valid = torch.ones(imgs.size(0), 1) + if self.on_gpu: + valid = valid.cuda(imgs.device.index) + real_loss = self.adversarial_loss(self.discriminator(imgs), valid) # how well can it label as fake? fake = torch.zeros(imgs.size(0), 1) + if self.on_gpu: + fake = fake.cuda(imgs.device.index) + fake_loss = self.adversarial_loss( self.discriminator(self.generated_imgs.detach()), fake)