working on dp state fix
This commit is contained in:
parent
699bd2cb50
commit
e611223dc8
|
@ -1,8 +1,8 @@
|
|||
"""
|
||||
To run this template just do:
|
||||
python gan.py
|
||||
To run this template just do:
|
||||
python gan.py
|
||||
|
||||
After a few epochs, launch tensorboard to see the images being generated at every batch.
|
||||
After a few epochs, launch tensorboard to see the images being generated at every batch.
|
||||
|
||||
tensorboard --logdir default
|
||||
"""
|
||||
|
@ -83,6 +83,7 @@ class GAN(pl.LightningModule):
|
|||
self.discriminator = Discriminator(img_shape=mnist_shape)
|
||||
|
||||
# cache for generated images
|
||||
# TODO: in dp, the state gets wiped... need to figure out how to fix
|
||||
self.generated_imgs = None
|
||||
self.last_imgs = None
|
||||
|
||||
|
@ -90,7 +91,7 @@ class GAN(pl.LightningModule):
|
|||
return self.generator(z)
|
||||
|
||||
def adversarial_loss(self, y_hat, y):
|
||||
return F.binary_cross_entropy(y_hat, y)
|
||||
return F.binary_cross_entropy(y_hat, y.cuda())
|
||||
|
||||
def training_step(self, batch, batch_nb, optimizer_i):
|
||||
imgs, _ = batch
|
||||
|
@ -147,7 +148,6 @@ class GAN(pl.LightningModule):
|
|||
'log': tqdm_dict
|
||||
})
|
||||
|
||||
|
||||
return output
|
||||
|
||||
def configure_optimizers(self):
|
||||
|
@ -188,7 +188,7 @@ def main(hparams):
|
|||
# ------------------------
|
||||
# 2 INIT TRAINER
|
||||
# ------------------------
|
||||
trainer = pl.Trainer()
|
||||
trainer = pl.Trainer(distributed_backend='dp', gpus=2)
|
||||
|
||||
# ------------------------
|
||||
# 3 START TRAINING
|
||||
|
|
Loading…
Reference in New Issue