updated reqs
This commit is contained in:
parent
0cf9fa1a60
commit
d77914e466
|
@ -105,10 +105,14 @@ class LightningTemplateModel(LightningModule):
|
|||
# acc
|
||||
labels_hat = torch.argmax(y_hat, dim=1)
|
||||
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
|
||||
val_acc = torch.tensor(val_acc)
|
||||
|
||||
if self.on_gpu:
|
||||
val_acc = val_acc.cuda(loss_val.device.index)
|
||||
|
||||
output = OrderedDict({
|
||||
'val_loss': loss_val,
|
||||
'val_acc': torch.tensor(val_acc).type(loss_val.dtype),
|
||||
'val_acc': val_acc,
|
||||
})
|
||||
|
||||
# can also return just a scalar instead of a dict (return loss_val)
|
||||
|
|
Loading…
Reference in New Issue