working on dp state fix

This commit is contained in:
William Falcon 2019-10-19 00:26:31 +02:00
parent 699bd2cb50
commit e611223dc8
1 changed files with 6 additions and 6 deletions

View File

@ -1,8 +1,8 @@
"""
To run this template just do:
python gan.py
To run this template just do:
python gan.py
After a few epochs, launch tensorboard to see the images being generated at every batch.
After a few epochs, launch tensorboard to see the images being generated at every batch.
tensorboard --logdir default
"""
@ -83,6 +83,7 @@ class GAN(pl.LightningModule):
self.discriminator = Discriminator(img_shape=mnist_shape)
# cache for generated images
# TODO: in dp, the state gets wiped... need to figure out how to fix
self.generated_imgs = None
self.last_imgs = None
@ -90,7 +91,7 @@ class GAN(pl.LightningModule):
return self.generator(z)
def adversarial_loss(self, y_hat, y):
return F.binary_cross_entropy(y_hat, y)
return F.binary_cross_entropy(y_hat, y.cuda())
def training_step(self, batch, batch_nb, optimizer_i):
imgs, _ = batch
@ -147,7 +148,6 @@ class GAN(pl.LightningModule):
'log': tqdm_dict
})
return output
def configure_optimizers(self):
@ -188,7 +188,7 @@ def main(hparams):
# ------------------------
# 2 INIT TRAINER
# ------------------------
trainer = pl.Trainer()
trainer = pl.Trainer(distributed_backend='dp', gpus=2)
# ------------------------
# 3 START TRAINING