From 0929908229de524a5886b7a3534dc0c909bcf5a2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 11 Jul 2019 15:08:45 -0400 Subject: [PATCH] simplify trainer output --- pytorch_lightning/models/trainer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index ad662d7503..39738e0a9b 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -611,10 +611,16 @@ class Trainer(TrainerIO): try: model_specific_tqdm_metrics_dic = output['tqdm_metrics'] - except TypeError as e: + except Exception as e: model_specific_tqdm_metrics_dic = {} - loss = output['loss'] + # if output dict doesn't have the keyword loss + # then assume the output=loss if scalar + try: + loss = output['loss'] + except Exception as e: + if type(loss) is torch.Tensor: + loss = output self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)