diff --git a/examples/domain_templates/gan.py b/examples/domain_templates/gan.py index a93cba847b..1e2c89f510 100644 --- a/examples/domain_templates/gan.py +++ b/examples/domain_templates/gan.py @@ -83,7 +83,6 @@ class GAN(pl.LightningModule): self.discriminator = Discriminator(img_shape=mnist_shape) # cache for generated images - # TODO: in dp, the state gets wiped... need to figure out how to fix self.generated_imgs = None self.last_imgs = None @@ -91,7 +90,7 @@ class GAN(pl.LightningModule): return self.generator(z) def adversarial_loss(self, y_hat, y): - return F.binary_cross_entropy(y_hat, y.cuda()) + return F.binary_cross_entropy(y_hat, y) def training_step(self, batch, batch_nb, optimizer_i): imgs, _ = batch @@ -148,6 +147,7 @@ class GAN(pl.LightningModule): 'log': tqdm_dict }) + return output def configure_optimizers(self): @@ -188,7 +188,7 @@ def main(hparams): # ------------------------ # 2 INIT TRAINER # ------------------------ - trainer = pl.Trainer(distributed_backend='dp', gpus=2) + trainer = pl.Trainer() # ------------------------ # 3 START TRAINING