From d77914e466ea8bbd077e467ff9eed75d90447338 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 09:29:46 -0400 Subject: [PATCH] updated reqs --- .../new_project_templates/lightning_module_template.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py index df7dbe328e..490ccebeb5 100644 --- a/pytorch_lightning/examples/new_project_templates/lightning_module_template.py +++ b/pytorch_lightning/examples/new_project_templates/lightning_module_template.py @@ -105,10 +105,14 @@ class LightningTemplateModel(LightningModule): # acc labels_hat = torch.argmax(y_hat, dim=1) val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) output = OrderedDict({ 'val_loss': loss_val, - 'val_acc': torch.tensor(val_acc).type(loss_val.dtype), + 'val_acc': val_acc, }) # can also return just a scalar instead of a dict (return loss_val)