diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 8fc4d5ee66..b27996a797 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -597,7 +597,7 @@ class Trainer(TrainerIO): try: loss = output['loss'] except Exception as e: - if type(loss) is torch.Tensor: + if type(output) is torch.Tensor: loss = output self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)