working on dp state fix

This commit is contained in:
William Falcon 2019-10-19 00:26:50 +02:00
parent e611223dc8
commit e1b45ca492
1 changed files with 3 additions and 3 deletions

View File

@ -83,7 +83,6 @@ class GAN(pl.LightningModule):
self.discriminator = Discriminator(img_shape=mnist_shape)
# cache for generated images
# TODO: in dp, the state gets wiped... need to figure out how to fix
self.generated_imgs = None
self.last_imgs = None
@ -91,7 +90,7 @@ class GAN(pl.LightningModule):
return self.generator(z)
def adversarial_loss(self, y_hat, y):
return F.binary_cross_entropy(y_hat, y.cuda())
return F.binary_cross_entropy(y_hat, y)
def training_step(self, batch, batch_nb, optimizer_i):
imgs, _ = batch
@ -148,6 +147,7 @@ class GAN(pl.LightningModule):
'log': tqdm_dict
})
return output
def configure_optimizers(self):
@ -188,7 +188,7 @@ def main(hparams):
# ------------------------
# 2 INIT TRAINER
# ------------------------
trainer = pl.Trainer(distributed_backend='dp', gpus=2)
trainer = pl.Trainer()
# ------------------------
# 3 START TRAINING