refactored model tests
This commit is contained in:
parent
0e9e07835c
commit
cf7da86c7c
|
@ -54,8 +54,11 @@ class LightningTemplateModel(LightningModule):
|
|||
:param x:
|
||||
:return:
|
||||
"""
|
||||
print('-'*100)
|
||||
print('x: ', x.device)
|
||||
print('model: ', self.c_d1.weight.device)
|
||||
print('-'*100)
|
||||
|
||||
print(x.device)
|
||||
x = self.c_d1(x)
|
||||
x = torch.tanh(x)
|
||||
x = self.c_d1_bn(x)
|
||||
|
@ -79,11 +82,7 @@ class LightningTemplateModel(LightningModule):
|
|||
# forward pass
|
||||
x, y = data_batch
|
||||
x = x.view(x.size(0), -1)
|
||||
print('-'*100)
|
||||
print('TRAIN')
|
||||
print('x: ', x.device)
|
||||
print('model: ', self.c_d1.weight.device)
|
||||
print('-'*100)
|
||||
|
||||
y_hat = self.forward(x)
|
||||
|
||||
# calculate loss
|
||||
|
@ -104,11 +103,6 @@ class LightningTemplateModel(LightningModule):
|
|||
"""
|
||||
x, y = data_batch
|
||||
x = x.view(x.size(0), -1)
|
||||
print('-'*100)
|
||||
print('VAL')
|
||||
print('x: ', x.device, x.shape)
|
||||
print('model: ', self.c_d1.weight.device, self.c_d1.bias.device)
|
||||
print('-'*100)
|
||||
y_hat = self.forward(x)
|
||||
|
||||
loss_val = self.loss(y, y_hat)
|
||||
|
|
Loading…
Reference in New Issue