Fix for accuracy calculation (#2183)

* accuracy_fix

* fix line length

* Apply suggestions from code review

* Update test_classification.py

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
commit 02262d0a93
parent c0903b800d
Author: Nicki Skafte
Date:   2020-06-15 00:14:29 +02:00 (committed by GitHub)
2 changed files with 77 additions and 31 deletions
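Note: the substance of the fix is that `accuracy` previously averaged the one-vs-rest binary accuracy (tp + tn) / (tp + tn + fp + fn) over classes, which over-counts true negatives in the multiclass case. The patch tracks each class's support (its number of occurrences in `target`) and computes accuracy as sum(tp) / sum(support), matching sklearn.metrics.accuracy_score. A minimal standalone sketch of the before/after behaviour (editorial illustration, not part of the diff):

    import torch

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])  # 3 of 4 predictions are correct
    num_classes = 4

    # Old formula: mean over classes of one-vs-rest accuracy; every class
    # a sample does NOT belong to contributes a true negative.
    old_acc = sum(
        ((pred == c) == (target == c)).float().mean() for c in range(num_classes)
    ) / num_classes
    print(old_acc)  # tensor(0.8750) -- inflated

    # New formula: correct predictions over total support. For single-label
    # input this equals sum(tp) / sum(support).
    new_acc = (pred == target).float().mean()
    print(new_acc)  # tensor(0.7500) -- matches sklearn.metrics.accuracy_score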

pytorch_lightning/metrics/functional/classification.py

@@ -96,8 +96,9 @@ def stat_scores(
     fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
     tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
     fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
+    sup = (target == class_index).to(torch.long).sum()

-    return tp, fp, tn, fn
+    return tp, fp, tn, fn, sup


 def stat_scores_multiple_classes(
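Note: `stat_scores` now returns five values, so every caller has to unpack the extra `sup`. A quick check of the new return against the test case further down (sketch, assuming this commit is installed):

    import torch
    from pytorch_lightning.metrics.functional.classification import stat_scores

    pred = torch.tensor([0., 2., 4., 4.])
    target = torch.tensor([0., 4., 3., 4.])
    tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)
    print(tp.item(), fp.item(), tn.item(), fn.item(), sup.item())  # 1 1 1 1 2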
@@ -132,12 +133,13 @@ def stat_scores_multiple_classes(
     fps = torch.zeros((num_classes,), device=pred.device)
     tns = torch.zeros((num_classes,), device=pred.device)
     fns = torch.zeros((num_classes,), device=pred.device)
+    sups = torch.zeros((num_classes,), device=pred.device)

     for c in range(num_classes):
-        tps[c], fps[c], tns[c], fns[c] = stat_scores(pred=pred, target=target,
-                                                     class_index=c)
+        tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred,
+                                                              target=target,
+                                                              class_index=c)

-    return tps, fps, tns, fns
+    return tps, fps, tns, fns, sups


 def accuracy(
@@ -163,15 +165,16 @@ def accuracy(
     Return:
         A Tensor with the classification score.
     """
-    tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
-                                                      num_classes=num_classes)
+    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target,
+                                                            num_classes=num_classes)

     if not (target > 0).any() and num_classes is None:
         raise RuntimeError("cannot infer num_classes when target is all zero")

-    accuracies = (tps + tns) / (tps + tns + fps + fns)
-
-    return reduce(accuracies, reduction=reduction)
+    if reduction in ('elementwise_mean', 'sum'):
+        return reduce(sum(tps) / sum(sups), reduction=reduction)
+    if reduction == 'none':
+        return reduce(tps / sups, reduction=reduction)


 def confusion_matrix(
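Note the per-reduction semantics introduced here: 'elementwise_mean' and 'sum' reduce the scalar sum(tps) / sum(sups), i.e. overall accuracy, while 'none' returns tps / sups, which is the per-class recall rather than the old per-class binary accuracy. Equivalent tensor arithmetic (sketch with made-up counts):

    import torch

    tps = torch.tensor([1., 1., 1., 0.])   # per-class true positives
    sups = torch.tensor([1., 1., 2., 0.])  # per-class support

    overall = tps.sum() / sups.sum()  # tensor(0.7500) -- 'elementwise_mean' / 'sum'
    per_class = tps / sups            # tensor([1., 1., 0.5, nan]) -- 'none'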
@@ -193,13 +196,10 @@ def confusion_matrix(
     """
     num_classes = get_num_classes(pred, target, None)

-    d = target.size(-1)
-    batch_vec = torch.arange(target.size(-1))
-    # this will account for multilabel
-    unique_labels = batch_vec * num_classes ** 2 + target.view(-1) * num_classes + pred.view(-1)
+    unique_labels = target.view(-1) * num_classes + pred.view(-1)

-    bins = torch.bincount(unique_labels, minlength=d * num_classes ** 2)
-    cm = bins.reshape(d, num_classes, num_classes).squeeze().float()
+    bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
+    cm = bins.reshape(num_classes, num_classes).squeeze().float()

     if normalize:
         cm = cm / cm.sum(-1)
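Note: the rewrite drops the broken multilabel batching and uses the standard bincount encoding: each (target, pred) pair maps to the flat index target * num_classes + pred, the indices are counted, and the counts are reshaped into a num_classes x num_classes matrix (rows = target, columns = prediction). Standalone sketch (`cm_sketch` is a hypothetical helper name):

    import torch

    def cm_sketch(pred, target, num_classes):
        # Encode each (target, pred) pair as one flat index, count, reshape.
        idx = target.view(-1) * num_classes + pred.view(-1)
        bins = torch.bincount(idx, minlength=num_classes ** 2)
        return bins.reshape(num_classes, num_classes).float()

    print(cm_sketch(torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 1, 2]), 3))
    # tensor([[1., 0., 0.],
    #         [0., 1., 1.],
    #         [0., 0., 1.]])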
@@ -230,9 +230,9 @@ def precision_recall(
     Return:
         Tensor with precision and recall
     """
-    tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
-                                                      target=target,
-                                                      num_classes=num_classes)
+    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred,
+                                                            target=target,
+                                                            num_classes=num_classes)

     tps = tps.to(torch.float)
     fps = fps.to(torch.float)
@@ -676,7 +676,7 @@ def dice_score(
             scores[i - bg] += no_fg_score
             continue

-        tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)
+        tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)
         denom = (2 * tp + fp + fn).to(torch.float)

tests/metrics/functional/test_classification.py

@@ -1,5 +1,6 @@
 import pytest
 import torch
+from functools import partial

 from pytorch_lightning import seed_everything
 from pytorch_lightning.metrics.functional.classification import (
@@ -23,6 +24,33 @@ from pytorch_lightning.metrics.functional.classification import (
     auc,
 )
+from sklearn.metrics import (
+    accuracy_score as sk_accuracy,
+    precision_score as sk_precision,
+    recall_score as sk_recall,
+    f1_score as sk_f1_score,
+    fbeta_score as sk_fbeta_score,
+    confusion_matrix as sk_confusion_matrix
+)
+
+
+@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
+    (sk_accuracy, accuracy),
+    (partial(sk_precision, average='macro'), precision),
+    (partial(sk_recall, average='macro'), recall),
+    (partial(sk_f1_score, average='macro'), f1_score),
+    (partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2)),
+    (sk_confusion_matrix, confusion_matrix)
+])
+def test_against_sklearn(sklearn_metric, torch_metric):
+    """Compare PL metrics to sklearn version."""
+    pred = torch.randint(10, (500,))
+    target = torch.randint(10, (500,))
+    assert torch.allclose(
+        torch.tensor(sklearn_metric(target, pred), dtype=torch.float),
+        torch_metric(pred, target))
+
+
 def test_onehot():
     test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
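Note the argument order in the new test: sklearn metrics take (y_true, y_pred) while the Lightning functional metrics take (pred, target), hence `sklearn_metric(target, pred)` next to `torch_metric(pred, target)`. The comparison can be run in isolation with (assuming the usual test layout):

    pytest tests/metrics/functional/test_classification.py -k test_against_sklearn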
@@ -113,32 +141,36 @@ def test_get_num_classes(pred, target, num_classes, expected_num_classes):
     assert get_num_classes(pred, target, num_classes) == expected_num_classes


-@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
-    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1),
-    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1)
+@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
+                          'expected_tn', 'expected_fn', 'expected_support'], [
+    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2),
+    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2)
 ])
-def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
-    tp, fp, tn, fn = stat_scores(pred, target, class_index=4)
+def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
+    tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)

     assert tp.item() == expected_tp
     assert fp.item() == expected_fp
     assert tn.item() == expected_tn
     assert fn.item() == expected_fn
+    assert sup.item() == expected_support


-@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
+@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
+                          'expected_tn', 'expected_fn', 'expected_support'], [
     pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
-                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]),
+                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
     pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
-                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1])
+                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
 ])
-def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
-    tp, fp, tn, fn = stat_scores_multiple_classes(pred, target)
+def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
+    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)

     assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
     assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
     assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
     assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
+    assert torch.allclose(torch.tensor(expected_support).to(sup), sup)


 def test_multilabel_accuracy():
@@ -146,7 +178,7 @@ def test_multilabel_accuracy():
     y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
     y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])

-    assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([0.8333333134651184] * 2))
+    assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.]))
     assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
     assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
     assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
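Note: the expected values change because reduction='none' now yields tp / support per class. For y1 vs y2, class 0 recovers 2 of its 3 target occurrences (2 / 3) and class 1 recovers all 3 (1.0); the old formula gave both classes (tp + tn) / 6 = 5 / 6 ≈ 0.8333.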
@@ -156,6 +188,20 @@ def test_multilabel_accuracy():
         accuracy(y2, torch.zeros_like(y2), reduction='none')


+def test_accuracy():
+    pred = torch.tensor([0, 1, 2, 3])
+    target = torch.tensor([0, 1, 2, 2])
+    acc = accuracy(pred, target)
+
+    assert acc.item() == 0.75
+
+    pred = torch.tensor([0, 1, 2, 2])
+    target = torch.tensor([0, 1, 1, 3])
+    acc = accuracy(pred, target)
+
+    assert acc.item() == 0.50
+
+
 def test_confusion_matrix():
     target = (torch.arange(120) % 3).view(-1, 1)
     pred = target.clone()