From 6bb3c0306a1413da18b89707642028efa1966861 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 28 Jul 2019 11:51:32 -0400 Subject: [PATCH] updated output of test models --- .../testing_models/lm_test_module.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py index 87de2c7777..8e0718b972 100644 --- a/pytorch_lightning/testing_models/lm_test_module.py +++ b/pytorch_lightning/testing_models/lm_test_module.py @@ -96,12 +96,15 @@ class LightningTestModel(LightningModule): if self.trainer.use_dp: loss_val = loss_val.unsqueeze(0) - output = OrderedDict({ - 'loss': loss_val - }) - - # can also return just a scalar instead of a dict (return loss_val) - return output + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val def validation_step(self, data_batch, batch_i): """