From d372f9a2e207f390d2154a6bd20c84b90e4a4715 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 28 Jul 2019 11:46:26 -0400 Subject: [PATCH 1/2] updated dict keys --- pytorch_lightning/models/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 5ce253c29c..e60a0d95cc 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -772,7 +772,7 @@ class Trainer(TrainerIO): output = self.model.training_step(data_batch, batch_nb) try: - model_specific_tqdm_metrics_dic = output['tqdm_metrics'] + model_specific_tqdm_metrics_dic = output['prog'] except Exception as e: model_specific_tqdm_metrics_dic = {} From 6bb3c0306a1413da18b89707642028efa1966861 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 28 Jul 2019 11:51:32 -0400 Subject: [PATCH 2/2] updated output of test models --- .../testing_models/lm_test_module.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/testing_models/lm_test_module.py b/pytorch_lightning/testing_models/lm_test_module.py index 87de2c7777..8e0718b972 100644 --- a/pytorch_lightning/testing_models/lm_test_module.py +++ b/pytorch_lightning/testing_models/lm_test_module.py @@ -96,12 +96,15 @@ class LightningTestModel(LightningModule): if self.trainer.use_dp: loss_val = loss_val.unsqueeze(0) - output = OrderedDict({ - 'loss': loss_val - }) - - # can also return just a scalar instead of a dict (return loss_val) - return output + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val def validation_step(self, data_batch, batch_i): """