diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index 059eb09abd..b6b241d72e 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -132,6 +132,7 @@ class Trainer(TrainerIO): 'batch_nb':'{}'.format(self.batch_nb), } tqdm_dic.update(self.tqdm_metrics) + return tqdm_dic def __layout_bookeeping(self, model): @@ -161,6 +162,9 @@ class Trainer(TrainerIO): def __add_tqdm_metrics(self, metrics): for k, v in metrics.items(): + if type(v) is torch.Tensor: + v = v.item() + self.tqdm_metrics[k] = v def validate(self, model, dataloader, max_batches):