From e611223dc812497fdde10eeb2bc96c516a168df6 Mon Sep 17 00:00:00 2001 From: William Falcon <waf2107@columbia.edu> Date: Sat, 19 Oct 2019 00:26:31 +0200 Subject: [PATCH] working on dp state fix --- examples/domain_templates/gan.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/domain_templates/gan.py b/examples/domain_templates/gan.py index 5a94fcd4bc..a93cba847b 100644 --- a/examples/domain_templates/gan.py +++ b/examples/domain_templates/gan.py @@ -1,8 +1,8 @@ """ -To run this template just do: -python gan.py +To run this template just do: +python gan.py -After a few epochs, launch tensorboard to see the images being generated at every batch. +After a few epochs, launch tensorboard to see the images being generated at every batch. tensorboard --logdir default """ @@ -83,6 +83,7 @@ class GAN(pl.LightningModule): self.discriminator = Discriminator(img_shape=mnist_shape) # cache for generated images + # TODO: in dp, the state gets wiped... need to figure out how to fix self.generated_imgs = None self.last_imgs = None @@ -90,7 +91,7 @@ class GAN(pl.LightningModule): return self.generator(z) def adversarial_loss(self, y_hat, y): - return F.binary_cross_entropy(y_hat, y) + return F.binary_cross_entropy(y_hat, y.cuda()) def training_step(self, batch, batch_nb, optimizer_i): imgs, _ = batch @@ -147,7 +148,6 @@ class GAN(pl.LightningModule): 'log': tqdm_dict }) - return output def configure_optimizers(self): @@ -188,7 +188,7 @@ def main(hparams): # ------------------------ # 2 INIT TRAINER # ------------------------ - trainer = pl.Trainer() + trainer = pl.Trainer(distributed_backend='dp', gpus=2) # ------------------------ # 3 START TRAINING