GANs in pl-examples updated for lightning-0.9 (#3152)
* gan updated for lightning-0.9 * bugs fixed
This commit is contained in:
parent
5f39ae804a
commit
f22292c5f2
|
@ -20,6 +20,7 @@ from torch.utils.data import DataLoader
|
|||
from torchvision.datasets import MNIST
|
||||
|
||||
from pytorch_lightning.core import LightningModule
|
||||
from pytorch_lightning import TrainResult
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
|
||||
|
||||
|
@ -127,12 +128,13 @@ class GAN(LightningModule):
|
|||
# adversarial loss is binary cross-entropy
|
||||
g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
|
||||
tqdm_dict = {'g_loss': g_loss}
|
||||
output = OrderedDict({
|
||||
'loss': g_loss,
|
||||
'progress_bar': tqdm_dict,
|
||||
'log': tqdm_dict
|
||||
})
|
||||
return output
|
||||
result = TrainResult(
|
||||
minimize=g_loss,
|
||||
checkpoint_on=True
|
||||
)
|
||||
result.log_dict(tqdm_dict)
|
||||
|
||||
return result
|
||||
|
||||
# train discriminator
|
||||
if optimizer_idx == 1:
|
||||
|
@ -154,12 +156,13 @@ class GAN(LightningModule):
|
|||
# discriminator loss is the average of these
|
||||
d_loss = (real_loss + fake_loss) / 2
|
||||
tqdm_dict = {'d_loss': d_loss}
|
||||
output = OrderedDict({
|
||||
'loss': d_loss,
|
||||
'progress_bar': tqdm_dict,
|
||||
'log': tqdm_dict
|
||||
})
|
||||
return output
|
||||
result = TrainResult(
|
||||
minimize=d_loss,
|
||||
checkpoint_on=True
|
||||
)
|
||||
result.log_dict(tqdm_dict)
|
||||
|
||||
return result
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.lr
|
||||
|
|
Loading…
Reference in New Issue