From e611223dc812497fdde10eeb2bc96c516a168df6 Mon Sep 17 00:00:00 2001
From: William Falcon <waf2107@columbia.edu>
Date: Sat, 19 Oct 2019 00:26:31 +0200
Subject: [PATCH] working on dp state fix

---
 examples/domain_templates/gan.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/examples/domain_templates/gan.py b/examples/domain_templates/gan.py
index 5a94fcd4bc..a93cba847b 100644
--- a/examples/domain_templates/gan.py
+++ b/examples/domain_templates/gan.py
@@ -1,8 +1,8 @@
 """
-To run this template just do:  
-python gan.py   
+To run this template just do:
+python gan.py
 
-After a few epochs, launch tensorboard to see the images being generated at every batch.   
+After a few epochs, launch tensorboard to see the images being generated at every batch.
 
 tensorboard --logdir default
 """
@@ -83,6 +83,7 @@ class GAN(pl.LightningModule):
         self.discriminator = Discriminator(img_shape=mnist_shape)
 
         # cache for generated images
+        # TODO: in dp, the state gets wiped... need to figure out how to fix
         self.generated_imgs = None
         self.last_imgs = None
 
@@ -90,7 +91,7 @@ class GAN(pl.LightningModule):
         return self.generator(z)
 
     def adversarial_loss(self, y_hat, y):
-        return F.binary_cross_entropy(y_hat, y)
+        return F.binary_cross_entropy(y_hat, y.cuda())
 
     def training_step(self, batch, batch_nb, optimizer_i):
         imgs, _ = batch
@@ -147,7 +148,6 @@ class GAN(pl.LightningModule):
                 'log': tqdm_dict
             })
 
-
             return output
 
     def configure_optimizers(self):
@@ -188,7 +188,7 @@ def main(hparams):
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-    trainer = pl.Trainer()
+    trainer = pl.Trainer(distributed_backend='dp', gpus=2)
 
     # ------------------------
     # 3 START TRAINING