refactored model tests
This commit is contained in:
parent
0e9e07835c
commit
cf7da86c7c
pytorch_lightning/examples/new_project_templates
|
@ -54,8 +54,11 @@ class LightningTemplateModel(LightningModule):
|
||||||
:param x:
|
:param x:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
print('-'*100)
|
||||||
|
print('x: ', x.device)
|
||||||
|
print('model: ', self.c_d1.weight.device)
|
||||||
|
print('-'*100)
|
||||||
|
|
||||||
print(x.device)
|
|
||||||
x = self.c_d1(x)
|
x = self.c_d1(x)
|
||||||
x = torch.tanh(x)
|
x = torch.tanh(x)
|
||||||
x = self.c_d1_bn(x)
|
x = self.c_d1_bn(x)
|
||||||
|
@ -79,11 +82,7 @@ class LightningTemplateModel(LightningModule):
|
||||||
# forward pass
|
# forward pass
|
||||||
x, y = data_batch
|
x, y = data_batch
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
print('-'*100)
|
|
||||||
print('TRAIN')
|
|
||||||
print('x: ', x.device)
|
|
||||||
print('model: ', self.c_d1.weight.device)
|
|
||||||
print('-'*100)
|
|
||||||
y_hat = self.forward(x)
|
y_hat = self.forward(x)
|
||||||
|
|
||||||
# calculate loss
|
# calculate loss
|
||||||
|
@ -104,11 +103,6 @@ class LightningTemplateModel(LightningModule):
|
||||||
"""
|
"""
|
||||||
x, y = data_batch
|
x, y = data_batch
|
||||||
x = x.view(x.size(0), -1)
|
x = x.view(x.size(0), -1)
|
||||||
print('-'*100)
|
|
||||||
print('VAL')
|
|
||||||
print('x: ', x.device, x.shape)
|
|
||||||
print('model: ', self.c_d1.weight.device, self.c_d1.bias.device)
|
|
||||||
print('-'*100)
|
|
||||||
y_hat = self.forward(x)
|
y_hat = self.forward(x)
|
||||||
|
|
||||||
loss_val = self.loss(y, y_hat)
|
loss_val = self.loss(y, y_hat)
|
||||||
|
|
Loading…
Reference in New Issue