Fix for accuracy calculation (#2183)
* accuracy_fix * fix line length * Apply suggestions from code review * Update test_classification.py Co-authored-by: Nicki Skafte <nugginea@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu>
This commit is contained in:
parent
c0903b800d
commit
02262d0a93
|
@ -96,8 +96,9 @@ def stat_scores(
|
||||||
fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
|
fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
|
||||||
tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
|
tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
|
||||||
fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
|
fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
|
||||||
|
sup = (target == class_index).to(torch.long).sum()
|
||||||
|
|
||||||
return tp, fp, tn, fn
|
return tp, fp, tn, fn, sup
|
||||||
|
|
||||||
|
|
||||||
def stat_scores_multiple_classes(
|
def stat_scores_multiple_classes(
|
||||||
|
@ -132,12 +133,13 @@ def stat_scores_multiple_classes(
|
||||||
fps = torch.zeros((num_classes,), device=pred.device)
|
fps = torch.zeros((num_classes,), device=pred.device)
|
||||||
tns = torch.zeros((num_classes,), device=pred.device)
|
tns = torch.zeros((num_classes,), device=pred.device)
|
||||||
fns = torch.zeros((num_classes,), device=pred.device)
|
fns = torch.zeros((num_classes,), device=pred.device)
|
||||||
|
sups = torch.zeros((num_classes,), device=pred.device)
|
||||||
for c in range(num_classes):
|
for c in range(num_classes):
|
||||||
tps[c], fps[c], tns[c], fns[c] = stat_scores(pred=pred, target=target,
|
tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred,
|
||||||
class_index=c)
|
target=target,
|
||||||
|
class_index=c)
|
||||||
|
|
||||||
return tps, fps, tns, fns
|
return tps, fps, tns, fns, sups
|
||||||
|
|
||||||
|
|
||||||
def accuracy(
|
def accuracy(
|
||||||
|
@ -163,15 +165,16 @@ def accuracy(
|
||||||
Return:
|
Return:
|
||||||
A Tensor with the classification score.
|
A Tensor with the classification score.
|
||||||
"""
|
"""
|
||||||
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
|
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target,
|
||||||
num_classes=num_classes)
|
num_classes=num_classes)
|
||||||
|
|
||||||
if not (target > 0).any() and num_classes is None:
|
if not (target > 0).any() and num_classes is None:
|
||||||
raise RuntimeError("cannot infer num_classes when target is all zero")
|
raise RuntimeError("cannot infer num_classes when target is all zero")
|
||||||
|
|
||||||
accuracies = (tps + tns) / (tps + tns + fps + fns)
|
if reduction in ('elementwise_mean', 'sum'):
|
||||||
|
return reduce(sum(tps) / sum(sups), reduction=reduction)
|
||||||
return reduce(accuracies, reduction=reduction)
|
if reduction == 'none':
|
||||||
|
return reduce(tps / sups, reduction=reduction)
|
||||||
|
|
||||||
|
|
||||||
def confusion_matrix(
|
def confusion_matrix(
|
||||||
|
@ -193,13 +196,10 @@ def confusion_matrix(
|
||||||
"""
|
"""
|
||||||
num_classes = get_num_classes(pred, target, None)
|
num_classes = get_num_classes(pred, target, None)
|
||||||
|
|
||||||
d = target.size(-1)
|
unique_labels = target.view(-1) * num_classes + pred.view(-1)
|
||||||
batch_vec = torch.arange(target.size(-1))
|
|
||||||
# this will account for multilabel
|
|
||||||
unique_labels = batch_vec * num_classes ** 2 + target.view(-1) * num_classes + pred.view(-1)
|
|
||||||
|
|
||||||
bins = torch.bincount(unique_labels, minlength=d * num_classes ** 2)
|
bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
|
||||||
cm = bins.reshape(d, num_classes, num_classes).squeeze().float()
|
cm = bins.reshape(num_classes, num_classes).squeeze().float()
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
cm = cm / cm.sum(-1)
|
cm = cm / cm.sum(-1)
|
||||||
|
@ -230,9 +230,9 @@ def precision_recall(
|
||||||
Return:
|
Return:
|
||||||
Tensor with precision and recall
|
Tensor with precision and recall
|
||||||
"""
|
"""
|
||||||
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
|
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred,
|
||||||
target=target,
|
target=target,
|
||||||
num_classes=num_classes)
|
num_classes=num_classes)
|
||||||
|
|
||||||
tps = tps.to(torch.float)
|
tps = tps.to(torch.float)
|
||||||
fps = fps.to(torch.float)
|
fps = fps.to(torch.float)
|
||||||
|
@ -676,7 +676,7 @@ def dice_score(
|
||||||
scores[i - bg] += no_fg_score
|
scores[i - bg] += no_fg_score
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)
|
tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)
|
||||||
|
|
||||||
denom = (2 * tp + fp + fn).to(torch.float)
|
denom = (2 * tp + fp + fn).to(torch.float)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from pytorch_lightning.metrics.functional.classification import (
|
from pytorch_lightning.metrics.functional.classification import (
|
||||||
|
@ -23,6 +24,33 @@ from pytorch_lightning.metrics.functional.classification import (
|
||||||
auc,
|
auc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from sklearn.metrics import (
|
||||||
|
accuracy_score as sk_accuracy,
|
||||||
|
precision_score as sk_precision,
|
||||||
|
recall_score as sk_recall,
|
||||||
|
f1_score as sk_f1_score,
|
||||||
|
fbeta_score as sk_fbeta_score,
|
||||||
|
confusion_matrix as sk_confusion_matrix
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
|
||||||
|
(sk_accuracy, accuracy),
|
||||||
|
(partial(sk_precision, average='macro'), precision),
|
||||||
|
(partial(sk_recall, average='macro'), recall),
|
||||||
|
(partial(sk_f1_score, average='macro'), f1_score),
|
||||||
|
(partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2)),
|
||||||
|
(sk_confusion_matrix, confusion_matrix)
|
||||||
|
])
|
||||||
|
def test_against_sklearn(sklearn_metric, torch_metric):
|
||||||
|
"""Compare PL metrics to sklearn version."""
|
||||||
|
pred = torch.randint(10, (500,))
|
||||||
|
target = torch.randint(10, (500,))
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
|
||||||
|
torch_metric(pred, target))
|
||||||
|
|
||||||
|
|
||||||
def test_onehot():
|
def test_onehot():
|
||||||
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||||
|
@ -113,32 +141,36 @@ def test_get_num_classes(pred, target, num_classes, expected_num_classes):
|
||||||
assert get_num_classes(pred, target, num_classes) == expected_num_classes
|
assert get_num_classes(pred, target, num_classes) == expected_num_classes
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
|
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
|
||||||
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1),
|
'expected_tn', 'expected_fn', 'expected_support'], [
|
||||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1)
|
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2),
|
||||||
|
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2)
|
||||||
])
|
])
|
||||||
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
|
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
|
||||||
tp, fp, tn, fn = stat_scores(pred, target, class_index=4)
|
tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)
|
||||||
|
|
||||||
assert tp.item() == expected_tp
|
assert tp.item() == expected_tp
|
||||||
assert fp.item() == expected_fp
|
assert fp.item() == expected_fp
|
||||||
assert tn.item() == expected_tn
|
assert tn.item() == expected_tn
|
||||||
assert fn.item() == expected_fn
|
assert fn.item() == expected_fn
|
||||||
|
assert sup.item() == expected_support
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
|
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
|
||||||
|
'expected_tn', 'expected_fn', 'expected_support'], [
|
||||||
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
|
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
|
||||||
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]),
|
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
|
||||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
|
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
|
||||||
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1])
|
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
|
||||||
])
|
])
|
||||||
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
|
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
|
||||||
tp, fp, tn, fn = stat_scores_multiple_classes(pred, target)
|
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)
|
||||||
|
|
||||||
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
|
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
|
||||||
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
|
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
|
||||||
assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
|
assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
|
||||||
assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
|
assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
|
||||||
|
assert torch.allclose(torch.tensor(expected_support).to(sup), sup)
|
||||||
|
|
||||||
|
|
||||||
def test_multilabel_accuracy():
|
def test_multilabel_accuracy():
|
||||||
|
@ -146,7 +178,7 @@ def test_multilabel_accuracy():
|
||||||
y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
|
y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
|
||||||
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
|
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
|
||||||
|
|
||||||
assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([0.8333333134651184] * 2))
|
assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.]))
|
||||||
assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
|
assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
|
||||||
assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
|
assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
|
||||||
assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
|
assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
|
||||||
|
@ -156,6 +188,20 @@ def test_multilabel_accuracy():
|
||||||
accuracy(y2, torch.zeros_like(y2), reduction='none')
|
accuracy(y2, torch.zeros_like(y2), reduction='none')
|
||||||
|
|
||||||
|
|
||||||
|
def test_accuracy():
|
||||||
|
pred = torch.tensor([0, 1, 2, 3])
|
||||||
|
target = torch.tensor([0, 1, 2, 2])
|
||||||
|
acc = accuracy(pred, target)
|
||||||
|
|
||||||
|
assert acc.item() == 0.75
|
||||||
|
|
||||||
|
pred = torch.tensor([0, 1, 2, 2])
|
||||||
|
target = torch.tensor([0, 1, 1, 3])
|
||||||
|
acc = accuracy(pred, target)
|
||||||
|
|
||||||
|
assert acc.item() == 0.50
|
||||||
|
|
||||||
|
|
||||||
def test_confusion_matrix():
|
def test_confusion_matrix():
|
||||||
target = (torch.arange(120) % 3).view(-1, 1)
|
target = (torch.arange(120) % 3).view(-1, 1)
|
||||||
pred = target.clone()
|
pred = target.clone()
|
||||||
|
|
Loading…
Reference in New Issue