Fix for accuracy calculation (#2183)

* accuracy_fix
* fix line length
* Apply suggestions from code review
* Update test_classification.py

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent c0903b800d
commit 02262d0a93
pytorch_lightning/metrics/functional/classification.py

@@ -96,8 +96,9 @@ def stat_scores(
     fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
     tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
     fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
+    sup = (target == class_index).to(torch.long).sum()
 
-    return tp, fp, tn, fn
+    return tp, fp, tn, fn, sup
 
 
 def stat_scores_multiple_classes(
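The new `sup` value is the class support, i.e. how many target elements belong to `class_index`. A minimal self-contained sketch of the counting logic (reimplemented here for illustration; the values come from the updated tests further down):

```python
import torch

def stat_scores_sketch(pred, target, class_index):
    # One-vs-rest counts for a single class, mirroring the hunk above.
    tp = ((pred == class_index) & (target == class_index)).long().sum()
    fp = ((pred == class_index) & (target != class_index)).long().sum()
    tn = ((pred != class_index) & (target != class_index)).long().sum()
    fn = ((pred != class_index) & (target == class_index)).long().sum()
    sup = (target == class_index).long().sum()  # new: class support
    return tp, fp, tn, fn, sup

pred = torch.tensor([0., 2., 4., 4.])
target = torch.tensor([0., 4., 3., 4.])
print(stat_scores_sketch(pred, target, class_index=4))  # (1, 1, 1, 1, 2)
```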
@@ -132,12 +133,13 @@ def stat_scores_multiple_classes(
     fps = torch.zeros((num_classes,), device=pred.device)
     tns = torch.zeros((num_classes,), device=pred.device)
     fns = torch.zeros((num_classes,), device=pred.device)
-
+    sups = torch.zeros((num_classes,), device=pred.device)
     for c in range(num_classes):
-        tps[c], fps[c], tns[c], fns[c] = stat_scores(pred=pred, target=target,
-                                                     class_index=c)
+        tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred,
+                                                              target=target,
+                                                              class_index=c)
 
-    return tps, fps, tns, fns
+    return tps, fps, tns, fns, sups
 
 
 def accuracy(
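With the loop updated, callers unpack five per-class tensors; `sups[c]` counts the target entries of class `c`. A quick usage sketch against the v0.8-era functional API (expected values taken from the updated tests below):

```python
import torch
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes

pred = torch.tensor([0., 2., 4., 4.])
target = torch.tensor([0., 4., 3., 4.])

tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target)
print(sups)  # tensor([1., 0., 0., 1., 2.])
```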
@@ -163,15 +165,16 @@ def accuracy(
     Return:
         A Tensor with the classification score.
     """
-    tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
-                                                      num_classes=num_classes)
+    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target,
+                                                            num_classes=num_classes)
 
     if not (target > 0).any() and num_classes is None:
         raise RuntimeError("cannot infer num_classes when target is all zero")
 
-    accuracies = (tps + tns) / (tps + tns + fps + fns)
-
-    return reduce(accuracies, reduction=reduction)
+    if reduction in ('elementwise_mean', 'sum'):
+        return reduce(sum(tps) / sum(sups), reduction=reduction)
+    if reduction == 'none':
+        return reduce(tps / sups, reduction=reduction)
 
 
 def confusion_matrix(
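This is the substance of the fix: the old code averaged the per-class one-vs-rest quantity `(tp + tn) / (tp + tn + fp + fn)`, which is inflated by true negatives, while the new code computes standard (micro) accuracy as correct predictions over total support. A worked check on the example from the new `test_accuracy` below:

```python
import torch

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# Old behaviour: per-class (tp + tn) / n, then mean.
# tps = [1, 1, 1, 0], tns = [3, 3, 2, 3]
# mean([4/4, 4/4, 3/4, 3/4]) = 0.875  -> disagrees with sklearn
# New behaviour: sum(tps) / sum(sups) = correct / total.
correct = (pred == target).sum().float()  # 3, i.e. the sum of per-class tps
total = float(target.numel())             # 4, i.e. the sum of per-class supports
print(correct / total)                    # 0.75, matching sklearn.accuracy_score
```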
@@ -193,13 +196,10 @@ def confusion_matrix(
     """
     num_classes = get_num_classes(pred, target, None)
 
-    d = target.size(-1)
-    batch_vec = torch.arange(target.size(-1))
-    # this will account for multilabel
-    unique_labels = batch_vec * num_classes ** 2 + target.view(-1) * num_classes + pred.view(-1)
+    unique_labels = target.view(-1) * num_classes + pred.view(-1)
 
-    bins = torch.bincount(unique_labels, minlength=d * num_classes ** 2)
-    cm = bins.reshape(d, num_classes, num_classes).squeeze().float()
+    bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
+    cm = bins.reshape(num_classes, num_classes).squeeze().float()
 
     if normalize:
         cm = cm / cm.sum(-1)
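The confusion matrix uses a bincount trick: each `(target, pred)` pair is encoded as a unique integer `target * num_classes + pred`, histogrammed, and reshaped. The removed `batch_vec` lines attempted to generalise this to multilabel input; the fix drops that in favour of the plain multiclass computation. A standalone sketch of the retained logic:

```python
import torch

def confusion_matrix_sketch(pred, target, num_classes):
    # Encode each (target, pred) pair as a unique integer in [0, num_classes**2),
    # then histogram and reshape: rows are targets, columns are predictions.
    unique_labels = target.view(-1) * num_classes + pred.view(-1)
    bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
    return bins.reshape(num_classes, num_classes).float()

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
print(confusion_matrix_sketch(pred, target, num_classes=3))
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 1., 1.]])
```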
@@ -230,9 +230,9 @@ def precision_recall(
     Return:
         Tensor with precision and recall
     """
-    tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
-                                                      target=target,
-                                                      num_classes=num_classes)
+    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred,
+                                                            target=target,
+                                                            num_classes=num_classes)
 
     tps = tps.to(torch.float)
     fps = fps.to(torch.float)
@@ -676,7 +676,7 @@ def dice_score(
             scores[i - bg] += no_fg_score
             continue
 
-        tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)
+        tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)
 
         denom = (2 * tp + fp + fn).to(torch.float)
 
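Here the extra `sup` return value is simply unpacked and ignored; the Dice score for class `i` is still built from the other four counts as `2·tp / (2·tp + fp + fn)`. A tiny worked instance:

```python
import torch

# One true positive, one false positive, one false negative:
tp, fp, fn = torch.tensor(1), torch.tensor(1), torch.tensor(1)

denom = (2 * tp + fp + fn).to(torch.float)
score = (2 * tp).to(torch.float) / denom
print(score.item())  # 2 / 4 = 0.5
```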
test_classification.py

@@ -1,5 +1,6 @@
 import pytest
 import torch
+from functools import partial
 
 from pytorch_lightning import seed_everything
 from pytorch_lightning.metrics.functional.classification import (
@@ -23,6 +24,33 @@ from pytorch_lightning.metrics.functional.classification import (
     auc,
 )
 
+from sklearn.metrics import (
+    accuracy_score as sk_accuracy,
+    precision_score as sk_precision,
+    recall_score as sk_recall,
+    f1_score as sk_f1_score,
+    fbeta_score as sk_fbeta_score,
+    confusion_matrix as sk_confusion_matrix
+)
+
+
+@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
+    (sk_accuracy, accuracy),
+    (partial(sk_precision, average='macro'), precision),
+    (partial(sk_recall, average='macro'), recall),
+    (partial(sk_f1_score, average='macro'), f1_score),
+    (partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2)),
+    (sk_confusion_matrix, confusion_matrix)
+])
+def test_against_sklearn(sklearn_metric, torch_metric):
+    """Compare PL metrics to sklearn version."""
+    pred = torch.randint(10, (500,))
+    target = torch.randint(10, (500,))
+
+    assert torch.allclose(
+        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch_metric(pred, target))
+
 
 def test_onehot():
     test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
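One detail worth noting in the new test: sklearn metrics take `(y_true, y_pred)` while the Lightning functional metrics take `(pred, target)`, hence the swapped argument order in the two calls. A quick standalone sanity check of the sklearn convention (assuming scikit-learn is installed):

```python
import torch
from sklearn.metrics import accuracy_score

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])

# sklearn: ground truth first, prediction second.
print(accuracy_score(target.numpy(), pred.numpy()))  # 0.75
```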
@@ -113,32 +141,36 @@ def test_get_num_classes(pred, target, num_classes, expected_num_classes):
     assert get_num_classes(pred, target, num_classes) == expected_num_classes
 
 
-@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
-    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1),
-    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1)
+@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
+                          'expected_tn', 'expected_fn', 'expected_support'], [
+    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2),
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2)
 ])
-def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
-    tp, fp, tn, fn = stat_scores(pred, target, class_index=4)
+def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
+    tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)
 
     assert tp.item() == expected_tp
     assert fp.item() == expected_fp
     assert tn.item() == expected_tn
     assert fn.item() == expected_fn
+    assert sup.item() == expected_support
 
 
-@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
+@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
+                          'expected_tn', 'expected_fn', 'expected_support'], [
     pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
-                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]),
+                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
     pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
-                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1])
+                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
 ])
-def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
-    tp, fp, tn, fn = stat_scores_multiple_classes(pred, target)
+def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
+    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)
 
     assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
     assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
     assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
     assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
+    assert torch.allclose(torch.tensor(expected_support).to(sup), sup)
 
 
 def test_multilabel_accuracy():
@@ -146,7 +178,7 @@ def test_multilabel_accuracy():
     y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
     y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
 
-    assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([0.8333333134651184] * 2))
+    assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.]))
     assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
     assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
     assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
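The updated expected values follow directly from the fixed definition: with `reduction='none'`, accuracy returns per-class `tps / sups` over the label values `{0, 1}`. Worked out for `accuracy(y1, y2, reduction='none')`, where `y2` is the target:

```python
import torch

y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])  # pred
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])  # target

# class 0: target has three 0s, pred agrees on two of them -> 2/3
# class 1: target has three 1s, pred agrees on all three   -> 1.0
for c in (0, 1):
    tp = ((y1 == c) & (y2 == c)).sum().float()
    sup = (y2 == c).sum().float()
    print(c, (tp / sup).item())  # 0 -> 0.666..., 1 -> 1.0
```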
@@ -156,6 +188,20 @@ def test_multilabel_accuracy():
         accuracy(y2, torch.zeros_like(y2), reduction='none')
 
 
+def test_accuracy():
+    pred = torch.tensor([0, 1, 2, 3])
+    target = torch.tensor([0, 1, 2, 2])
+    acc = accuracy(pred, target)
+
+    assert acc.item() == 0.75
+
+    pred = torch.tensor([0, 1, 2, 2])
+    target = torch.tensor([0, 1, 1, 3])
+    acc = accuracy(pred, target)
+
+    assert acc.item() == 0.50
+
+
 def test_confusion_matrix():
     target = (torch.arange(120) % 3).view(-1, 1)
     pred = target.clone()