Bugfix/fix gan example (#2019)
* 🐛 fixed fake example type assigning and hparams arg
* fixed GAN example to work with dp, ddp., ddp_cpu
* Update generative_adversarial_net.py
Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
0e37e8c4d2
commit
55fdfe3845
|
@ -79,6 +79,7 @@ class GAN(LightningModule):
|
|||
b2: float = 0.999,
|
||||
batch_size: int = 64, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.latent_dim = latent_dim
|
||||
self.lr = lr
|
||||
self.b1 = b1
|
||||
|
@ -90,9 +91,7 @@ class GAN(LightningModule):
|
|||
self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
|
||||
self.discriminator = Discriminator(img_shape=mnist_shape)
|
||||
|
||||
# cache for generated images
|
||||
self.generated_imgs = None
|
||||
self.last_imgs = None
|
||||
self.validation_z = torch.randn(8, self.latent_dim)
|
||||
|
||||
def forward(self, z):
|
||||
return self.generator(z)
|
||||
|
@ -102,21 +101,21 @@ class GAN(LightningModule):
|
|||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
imgs, _ = batch
|
||||
self.last_imgs = imgs
|
||||
|
||||
# sample noise
|
||||
z = torch.randn(imgs.shape[0], self.latent_dim)
|
||||
z = z.type_as(imgs)
|
||||
|
||||
# train generator
|
||||
if optimizer_idx == 0:
|
||||
# sample noise
|
||||
z = torch.randn(imgs.shape[0], self.latent_dim)
|
||||
z = z.type_as(imgs)
|
||||
|
||||
# generate images
|
||||
self.generated_imgs = self(z)
|
||||
|
||||
# log sampled images
|
||||
# sample_imgs = self.generated_imgs[:6]
|
||||
# grid = torchvision.utils.make_grid(sample_imgs)
|
||||
# self.logger.experiment.add_image('generated_images', grid, 0)
|
||||
sample_imgs = self.generated_imgs[:6]
|
||||
grid = torchvision.utils.make_grid(sample_imgs)
|
||||
self.logger.experiment.add_image('generated_images', grid, 0)
|
||||
|
||||
# ground truth result (ie: all fake)
|
||||
# put on GPU because we created this tensor inside training_loop
|
||||
|
@ -124,7 +123,7 @@ class GAN(LightningModule):
|
|||
valid = valid.type_as(imgs)
|
||||
|
||||
# adversarial loss is binary cross-entropy
|
||||
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
|
||||
g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
|
||||
tqdm_dict = {'g_loss': g_loss}
|
||||
output = OrderedDict({
|
||||
'loss': g_loss,
|
||||
|
@ -145,10 +144,10 @@ class GAN(LightningModule):
|
|||
|
||||
# how well can it label as fake?
|
||||
fake = torch.zeros(imgs.size(0), 1)
|
||||
fake = fake.type_as(fake)
|
||||
fake = fake.type_as(imgs)
|
||||
|
||||
fake_loss = self.adversarial_loss(
|
||||
self.discriminator(self.generated_imgs.detach()), fake)
|
||||
self.discriminator(self(z).detach()), fake)
|
||||
|
||||
# discriminator loss is the average of these
|
||||
d_loss = (real_loss + fake_loss) / 2
|
||||
|
@ -176,8 +175,7 @@ class GAN(LightningModule):
|
|||
return DataLoader(dataset, batch_size=self.batch_size)
|
||||
|
||||
def on_epoch_end(self):
|
||||
z = torch.randn(8, self.latent_dim)
|
||||
z = z.type_as(self.last_imgs)
|
||||
z = self.validation_z.type_as(self.generator.model[0].weight)
|
||||
|
||||
# log sampled images
|
||||
sample_imgs = self(z)
|
||||
|
@ -185,15 +183,17 @@ class GAN(LightningModule):
|
|||
self.logger.experiment.add_image('generated_images', grid, self.current_epoch)
|
||||
|
||||
|
||||
def main(hparams):
|
||||
def main(args):
|
||||
# ------------------------
|
||||
# 1 INIT LIGHTNING MODEL
|
||||
# ------------------------
|
||||
model = GAN(hparams)
|
||||
model = GAN(**vars(args))
|
||||
|
||||
# ------------------------
|
||||
# 2 INIT TRAINER
|
||||
# ------------------------
|
||||
# If use distubuted training PyTorch recommends to use DistributedDataParallel.
|
||||
# See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
|
||||
trainer = Trainer()
|
||||
|
||||
# ------------------------
|
||||
|
|
Loading…
Reference in New Issue