From e1b45ca492f812b91695ae2d6d55830e5a8e0ec5 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 19 Oct 2019 00:26:50 +0200 Subject: [PATCH] working on dp state fix --- examples/domain_templates/gan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/domain_templates/gan.py b/examples/domain_templates/gan.py index a93cba847b..1e2c89f510 100644 --- a/examples/domain_templates/gan.py +++ b/examples/domain_templates/gan.py @@ -83,7 +83,6 @@ class GAN(pl.LightningModule): self.discriminator = Discriminator(img_shape=mnist_shape) # cache for generated images - # TODO: in dp, the state gets wiped... need to figure out how to fix self.generated_imgs = None self.last_imgs = None @@ -91,7 +90,7 @@ class GAN(pl.LightningModule): return self.generator(z) def adversarial_loss(self, y_hat, y): - return F.binary_cross_entropy(y_hat, y.cuda()) + return F.binary_cross_entropy(y_hat, y) def training_step(self, batch, batch_nb, optimizer_i): imgs, _ = batch @@ -148,6 +147,7 @@ class GAN(pl.LightningModule): 'log': tqdm_dict }) + return output def configure_optimizers(self): @@ -188,7 +188,7 @@ def main(hparams): # ------------------------ # 2 INIT TRAINER # ------------------------ - trainer = pl.Trainer(distributed_backend='dp', gpus=2) + trainer = pl.Trainer() # ------------------------ # 3 START TRAINING