diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 5ce253c29c..e60a0d95cc 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -772,7 +772,7 @@ class Trainer(TrainerIO):
         output = self.model.training_step(data_batch, batch_nb)
 
         try:
-            model_specific_tqdm_metrics_dic = output['tqdm_metrics']
+            model_specific_tqdm_metrics_dic = output['prog']
         except Exception as e:
             model_specific_tqdm_metrics_dic = {}
 
diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py
index 87de2c7777..8e0718b972 100644
--- a/pytorch_lightning/testing_models/lm_test_module.py
+++ b/pytorch_lightning/testing_models/lm_test_module.py
@@ -96,12 +96,15 @@ class LightningTestModel(LightningModule):
         if self.trainer.use_dp:
             loss_val = loss_val.unsqueeze(0)
 
-        output = OrderedDict({
-            'loss': loss_val
-        })
-
-        # can also return just a scalar instead of a dict (return loss_val)
-        return output
+        # alternate possible outputs to test: a dict with progress metrics
+        # on even batches, a bare scalar loss on odd batches
+        if self.trainer.batch_nb % 2 == 0:
+            output = OrderedDict({
+                'loss': loss_val,
+                'prog': {'some_val': loss_val * loss_val}
+            })
+            return output
+        return loss_val
 
     def validation_step(self, data_batch, batch_i):
         """
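
For context, a minimal sketch of what a user-defined training_step would return after this rename: progress-bar metrics now live under the 'prog' key instead of 'tqdm_metrics'. The class and metric names below are hypothetical, and a plain torch.nn.Module stands in for the real LightningModule base class to keep the snippet self-contained:

    from collections import OrderedDict

    import torch
    import torch.nn.functional as F


    class TinyModule(torch.nn.Module):
        """Stand-in for a user's LightningModule (hypothetical)."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)

        def training_step(self, data_batch, batch_nb):
            x, y = data_batch
            loss_val = F.mse_loss(self.layer(x), y)
            # The Trainer now reads progress-bar metrics from
            # output['prog']; if the key (or the dict itself) is missing,
            # the except branch in trainer.py falls back to an empty dict.
            return OrderedDict({
                'loss': loss_val,
                'prog': {'train_loss': loss_val},
            })


    # Example: one step over a random batch.
    module = TinyModule()
    out = module.training_step((torch.randn(8, 4), torch.randn(8, 1)), 0)
    assert 'prog' in out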