diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py index 87de2c7777..8e0718b972 100644 --- a/pytorch_lightning/testing_models/lm_test_module.py +++ b/pytorch_lightning/testing_models/lm_test_module.py @@ -96,12 +96,15 @@ class LightningTestModel(LightningModule): if self.trainer.use_dp: loss_val = loss_val.unsqueeze(0) - output = OrderedDict({ - 'loss': loss_val - }) - - # can also return just a scalar instead of a dict (return loss_val) - return output + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val def validation_step(self, data_batch, batch_i): """