refactored model tests

This commit is contained in:
William Falcon 2019-07-24 13:55:20 -04:00
parent 0e9e07835c
commit cf7da86c7c
1 changed files with 5 additions and 11 deletions

View File

@ -54,8 +54,11 @@ class LightningTemplateModel(LightningModule):
:param x:
:return:
"""
print('-'*100)
print('x: ', x.device)
print('model: ', self.c_d1.weight.device)
print('-'*100)
print(x.device)
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
@ -79,11 +82,7 @@ class LightningTemplateModel(LightningModule):
# forward pass
x, y = data_batch
x = x.view(x.size(0), -1)
print('-'*100)
print('TRAIN')
print('x: ', x.device)
print('model: ', self.c_d1.weight.device)
print('-'*100)
y_hat = self.forward(x)
# calculate loss
@ -104,11 +103,6 @@ class LightningTemplateModel(LightningModule):
"""
x, y = data_batch
x = x.view(x.size(0), -1)
print('-'*100)
print('VAL')
print('x: ', x.device, x.shape)
print('model: ', self.c_d1.weight.device, self.c_d1.bias.device)
print('-'*100)
y_hat = self.forward(x)
loss_val = self.loss(y, y_hat)