From cf7da86c7c15afdb503e78af045bd22463e2e9f7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 13:55:20 -0400 Subject: [PATCH] refactored model tests --- .../lightning_module_template.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py index 16f588aeb6..db065deff5 100644 --- a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py +++ b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py @@ -54,8 +54,11 @@ class LightningTemplateModel(LightningModule): :param x: :return: """ + print('-'*100) + print('x: ', x.device) + print('model: ', self.c_d1.weight.device) + print('-'*100) - print(x.device) x = self.c_d1(x) x = torch.tanh(x) x = self.c_d1_bn(x) @@ -79,11 +82,7 @@ class LightningTemplateModel(LightningModule): # forward pass x, y = data_batch x = x.view(x.size(0), -1) - print('-'*100) - print('TRAIN') - print('x: ', x.device) - print('model: ', self.c_d1.weight.device) - print('-'*100) + y_hat = self.forward(x) # calculate loss @@ -104,11 +103,6 @@ class LightningTemplateModel(LightningModule): """ x, y = data_batch x = x.view(x.size(0), -1) - print('-'*100) - print('VAL') - print('x: ', x.device, x.shape) - print('model: ', self.c_d1.weight.device, self.c_d1.bias.device) - print('-'*100) y_hat = self.forward(x) loss_val = self.loss(y, y_hat)