GANs in pl-examples updated for lightning-0.9 (#3152)

* gan updated for lightning-0.9

* bugs fixed
This commit is contained in:
Vasudev Gupta 2020-08-25 20:35:03 +05:30 committed by GitHub
parent 5f39ae804a
commit f22292c5f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 12 deletions

View File

@ -20,6 +20,7 @@ from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from pytorch_lightning.core import LightningModule
from pytorch_lightning import TrainResult
from pytorch_lightning.trainer import Trainer
@ -127,12 +128,13 @@ class GAN(LightningModule):
# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
tqdm_dict = {'g_loss': g_loss}
output = OrderedDict({
'loss': g_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
})
return output
result = TrainResult(
minimize=g_loss,
checkpoint_on=True
)
result.log_dict(tqdm_dict)
return result
# train discriminator
if optimizer_idx == 1:
@ -154,12 +156,13 @@ class GAN(LightningModule):
# discriminator loss is the average of these
d_loss = (real_loss + fake_loss) / 2
tqdm_dict = {'d_loss': d_loss}
output = OrderedDict({
'loss': d_loss,
'progress_bar': tqdm_dict,
'log': tqdm_dict
})
return output
result = TrainResult(
minimize=d_loss,
checkpoint_on=True
)
result.log_dict(tqdm_dict)
return result
def configure_optimizers(self):
lr = self.lr