From 02262d0a93e7b050b5ebfb643056705a3a059ef8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 15 Jun 2020 00:14:29 +0200 Subject: [PATCH] Fix for accuracy calculation (#2183) * accuracy_fix * fix line length * Apply suggestions from code review * Update test_classification.py Co-authored-by: Nicki Skafte Co-authored-by: Jirka Borovec Co-authored-by: William Falcon --- .../metrics/functional/classification.py | 40 +++++------ .../metrics/functional/test_classification.py | 68 ++++++++++++++++--- 2 files changed, 77 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 6c792af8d7..64d821fbb3 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -96,8 +96,9 @@ def stat_scores( fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum() tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum() fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum() + sup = (target == class_index).to(torch.long).sum() - return tp, fp, tn, fn + return tp, fp, tn, fn, sup def stat_scores_multiple_classes( @@ -132,12 +133,13 @@ def stat_scores_multiple_classes( fps = torch.zeros((num_classes,), device=pred.device) tns = torch.zeros((num_classes,), device=pred.device) fns = torch.zeros((num_classes,), device=pred.device) - + sups = torch.zeros((num_classes,), device=pred.device) for c in range(num_classes): - tps[c], fps[c], tns[c], fns[c] = stat_scores(pred=pred, target=target, - class_index=c) + tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, + target=target, + class_index=c) - return tps, fps, tns, fns + return tps, fps, tns, fns, sups def accuracy( @@ -163,15 +165,16 @@ def accuracy( Return: A Tensor with the classification score. 
""" - tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target, - num_classes=num_classes) + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, + num_classes=num_classes) if not (target > 0).any() and num_classes is None: raise RuntimeError("cannot infer num_classes when target is all zero") - accuracies = (tps + tns) / (tps + tns + fps + fns) - - return reduce(accuracies, reduction=reduction) + if reduction in ('elementwise_mean', 'sum'): + return reduce(sum(tps) / sum(sups), reduction=reduction) + if reduction == 'none': + return reduce(tps / sups, reduction=reduction) def confusion_matrix( @@ -193,13 +196,10 @@ def confusion_matrix( """ num_classes = get_num_classes(pred, target, None) - d = target.size(-1) - batch_vec = torch.arange(target.size(-1)) - # this will account for multilabel - unique_labels = batch_vec * num_classes ** 2 + target.view(-1) * num_classes + pred.view(-1) + unique_labels = target.view(-1) * num_classes + pred.view(-1) - bins = torch.bincount(unique_labels, minlength=d * num_classes ** 2) - cm = bins.reshape(d, num_classes, num_classes).squeeze().float() + bins = torch.bincount(unique_labels, minlength=num_classes ** 2) + cm = bins.reshape(num_classes, num_classes).squeeze().float() if normalize: cm = cm / cm.sum(-1) @@ -230,9 +230,9 @@ def precision_recall( Return: Tensor with precision and recall """ - tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, - target=target, - num_classes=num_classes) + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, + target=target, + num_classes=num_classes) tps = tps.to(torch.float) fps = fps.to(torch.float) @@ -676,7 +676,7 @@ def dice_score( scores[i - bg] += no_fg_score continue - tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i) + tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i) denom = (2 * tp + fp + fn).to(torch.float) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index a55a049a9a..01dbdd6dc6 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,5 +1,6 @@ import pytest import torch +from functools import partial from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import ( @@ -23,6 +24,33 @@ from pytorch_lightning.metrics.functional.classification import ( auc, ) +from sklearn.metrics import ( + accuracy_score as sk_accuracy, + precision_score as sk_precision, + recall_score as sk_recall, + f1_score as sk_f1_score, + fbeta_score as sk_fbeta_score, + confusion_matrix as sk_confusion_matrix +) + + +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + (sk_accuracy, accuracy), + (partial(sk_precision, average='macro'), precision), + (partial(sk_recall, average='macro'), recall), + (partial(sk_f1_score, average='macro'), f1_score), + (partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2)), + (sk_confusion_matrix, confusion_matrix) +]) +def test_against_sklearn(sklearn_metric, torch_metric): + """Compare PL metrics to sklearn version.""" + pred = torch.randint(10, (500,)) + target = torch.randint(10, (500,)) + + assert torch.allclose( + torch.tensor(sklearn_metric(target, pred), dtype=torch.float), + torch_metric(pred, target)) + def test_onehot(): test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) @@ -113,32 +141,36 @@ def test_get_num_classes(pred, target, num_classes, 
expected_num_classes): assert get_num_classes(pred, target, num_classes) == expected_num_classes -@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1) +@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', + 'expected_tn', 'expected_fn', 'expected_support'], [ + pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2) ]) -def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn): - tp, fp, tn, fn = stat_scores(pred, target, class_index=4) +def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): + tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4) assert tp.item() == expected_tp assert fp.item() == expected_fp assert tn.item() == expected_tn assert fn.item() == expected_fn + assert sup.item() == expected_support -@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [ +@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', + 'expected_tn', 'expected_fn', 'expected_support'], [ pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]), + [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]) + [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]) ]) -def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn): - tp, fp, tn, fn = stat_scores_multiple_classes(pred, target) +def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): + tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target) assert torch.allclose(torch.tensor(expected_tp).to(tp), tp) assert torch.allclose(torch.tensor(expected_fp).to(fp), fp) assert torch.allclose(torch.tensor(expected_tn).to(tn), tn) assert torch.allclose(torch.tensor(expected_fn).to(fn), fn) + assert torch.allclose(torch.tensor(expected_support).to(sup), sup) def test_multilabel_accuracy(): @@ -146,7 +178,7 @@ def test_multilabel_accuracy(): y1 = torch.tensor([[0, 1, 1], [1, 0, 1]]) y2 = torch.tensor([[0, 0, 1], [1, 0, 1]]) - assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([0.8333333134651184] * 2)) + assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.])) assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.])) assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.])) assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.])) @@ -156,6 +188,20 @@ def test_multilabel_accuracy(): accuracy(y2, torch.zeros_like(y2), reduction='none') +def test_accuracy(): + pred = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 2, 2]) + acc = accuracy(pred, target) + + assert acc.item() == 0.75 + + pred = torch.tensor([0, 1, 2, 2]) + target = torch.tensor([0, 1, 1, 
3]) + acc = accuracy(pred, target) + + assert acc.item() == 0.50 + + def test_confusion_matrix(): target = (torch.arange(120) % 3).view(-1, 1) pred = target.clone()
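
Note for reviewers (not part of the patch): a minimal standalone sketch contrasting the old per-class (tp + tn) / (tp + fp + tn + fn) average with the micro accuracy sum(tp) / sum(support) that this change switches to. The tensor values and variable names below are illustrative only; the formulas mirror the diff, not the library API.

    import torch

    # Illustrative example -- reproduces the two formulas from the diff by hand.
    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])
    num_classes = 4

    tps, fps, tns, fns, sups = [], [], [], [], []
    for c in range(num_classes):
        tps.append(((pred == c) & (target == c)).sum())
        fps.append(((pred == c) & (target != c)).sum())
        tns.append(((pred != c) & (target != c)).sum())
        fns.append(((pred != c) & (target == c)).sum())
        sups.append((target == c).sum())  # support = number of true samples per class
    tps, fps, tns, fns, sups = map(torch.stack, (tps, fps, tns, fns, sups))

    # Old behaviour: mean of per-class (tp + tn) / (tp + fp + tn + fn),
    # which counts true negatives and so over-reports multiclass accuracy.
    old_acc = ((tps + tns).float() / (tps + fps + tns + fns).float()).mean()

    # New behaviour: micro accuracy sum(tp) / sum(support).
    new_acc = tps.sum().float() / sups.sum().float()

    print(old_acc.item())  # 0.875 -- inflated by true negatives
    print(new_acc.item())  # 0.75  -- agrees with sklearn.metrics.accuracy_score(target, pred)

This is also why the new test_against_sklearn parametrization can compare accuracy directly against sklearn's accuracy_score.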