updated args
This commit is contained in:
parent
2ac5cce67a
commit
7cb6e34beb
|
@ -65,12 +65,12 @@ class ExampleModel(RootModule):
|
|||
"""
|
||||
# forward pass
|
||||
x, y = data_batch
|
||||
print(x.shape)
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
# calculate loss
|
||||
loss_val = self.loss(y, y_hat)
|
||||
print(loss_val)
|
||||
|
||||
# tqdm_dic = {'tng_loss': loss_val.item()}
|
||||
# return loss_val, tqdm_dic
|
||||
|
@ -87,11 +87,11 @@ class ExampleModel(RootModule):
|
|||
:return:
|
||||
"""
|
||||
x, y = data_batch
|
||||
print(x.shape)
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
loss_val = self.loss(y, y_hat)
|
||||
print(loss_val)
|
||||
|
||||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
|
|
Loading…
Reference in New Issue