diff --git a/pl_examples/domain_templates/generative_adversarial_net.py b/pl_examples/domain_templates/generative_adversarial_net.py index 427312ecaf..65505c1240 100644 --- a/pl_examples/domain_templates/generative_adversarial_net.py +++ b/pl_examples/domain_templates/generative_adversarial_net.py @@ -20,6 +20,7 @@ from torch.utils.data import DataLoader from torchvision.datasets import MNIST from pytorch_lightning.core import LightningModule +from pytorch_lightning import TrainResult from pytorch_lightning.trainer import Trainer @@ -127,12 +128,13 @@ class GAN(LightningModule): # adversarial loss is binary cross-entropy g_loss = self.adversarial_loss(self.discriminator(self(z)), valid) tqdm_dict = {'g_loss': g_loss} - output = OrderedDict({ - 'loss': g_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - return output + result = TrainResult( + minimize=g_loss, + checkpoint_on=True + ) + result.log_dict(tqdm_dict) + + return result # train discriminator if optimizer_idx == 1: @@ -154,12 +156,13 @@ class GAN(LightningModule): # discriminator loss is the average of these d_loss = (real_loss + fake_loss) / 2 tqdm_dict = {'d_loss': d_loss} - output = OrderedDict({ - 'loss': d_loss, - 'progress_bar': tqdm_dict, - 'log': tqdm_dict - }) - return output + result = TrainResult( + minimize=d_loss, + checkpoint_on=True + ) + result.log_dict(tqdm_dict) + + return result def configure_optimizers(self): lr = self.lr