diff --git a/docs/source/examples/example_model.py b/docs/source/examples/example_model.py index 2cc2d5a070..c3d64e6769 100644 --- a/docs/source/examples/example_model.py +++ b/docs/source/examples/example_model.py @@ -65,12 +65,12 @@ class ExampleModel(RootModule): """ # forward pass x, y = data_batch - print(x.shape) x = x.view(x.size(0), -1) y_hat = self.forward(x) # calculate loss loss_val = self.loss(y, y_hat) + print(loss_val) # tqdm_dic = {'tng_loss': loss_val.item()} # return loss_val, tqdm_dic @@ -87,11 +87,11 @@ class ExampleModel(RootModule): :return: """ x, y = data_batch - print(x.shape) x = x.view(x.size(0), -1) y_hat = self.forward(x) loss_val = self.loss(y, y_hat) + print(loss_val) # acc labels_hat = torch.argmax(y_hat, dim=1)