updated args

This commit is contained in:
William Falcon 2019-06-25 20:10:23 -04:00
parent 2ac5cce67a
commit 7cb6e34beb
1 changed files with 2 additions and 2 deletions

View File

@ -65,12 +65,12 @@ class ExampleModel(RootModule):
"""
# forward pass
x, y = data_batch
print(x.shape)
x = x.view(x.size(0), -1)
y_hat = self.forward(x)
# calculate loss
loss_val = self.loss(y, y_hat)
print(loss_val)
# tqdm_dic = {'tng_loss': loss_val.item()}
# return loss_val, tqdm_dic
@ -87,11 +87,11 @@ class ExampleModel(RootModule):
:return:
"""
x, y = data_batch
print(x.shape)
x = x.view(x.size(0), -1)
y_hat = self.forward(x)
loss_val = self.loss(y, y_hat)
print(loss_val)
# acc
labels_hat = torch.argmax(y_hat, dim=1)