diff --git a/docs/source/examples/example_model.py b/docs/source/examples/example_model.py index a8954d904a..416344a9da 100644 --- a/docs/source/examples/example_model.py +++ b/docs/source/examples/example_model.py @@ -76,7 +76,7 @@ class ExampleModel(RootModule): # return loss_val, tqdm_dic output = OrderedDict({ - 'loss_val': loss_val, + 'loss_val': loss_val.unsqueeze(0), }) return output @@ -99,8 +99,8 @@ class ExampleModel(RootModule): # output = {'y_hat': y_hat, 'val_loss': loss_val.item(), 'val_acc': val_acc} output = OrderedDict({ - 'loss_val': loss_val, - 'val_acc': val_acc, + 'loss_val': loss_val.unsqueeze(0), + 'val_acc': val_acc.unsqueeze(0), }) return output