updated reqs

This commit is contained in:
William Falcon 2019-07-24 09:29:46 -04:00
parent 0cf9fa1a60
commit d77914e466
1 changed files with 5 additions and 1 deletions

View File

@ -105,10 +105,14 @@ class LightningTemplateModel(LightningModule):
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
if self.on_gpu:
val_acc = val_acc.cuda(loss_val.device.index)
output = OrderedDict({
'val_loss': loss_val,
'val_acc': torch.tensor(val_acc).type(loss_val.dtype),
'val_acc': val_acc,
})
# can also return just a scalar instead of a dict (return loss_val)