From 92f122e0df7e233f3a8b7873c7294155afbbf852 Mon Sep 17 00:00:00 2001
From: elias-ramzi <43226247+elias-ramzi@users.noreply.github.com>
Date: Tue, 23 Jun 2020 13:21:00 +0200
Subject: [PATCH] Fix average_precision metric (#2319)

* Fixed average_precision metric; parentheses were missing. Added a test that failed with the old implementation.

* Modified CHANGELOG.md

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec
Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                                     |  2 ++
 .../metrics/functional/classification.py         |  2 +-
 tests/metrics/functional/test_classification.py  | 13 +++++++------
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10d857f2a5..e3371a3db7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
 
+- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
+
 ## [0.8.1] - 2020-06-19
 
 
diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index 8b392e5a11..c5d3613640 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -844,7 +844,7 @@ def average_precision(
     # Return the step function integral
     # The following works because the last entry of precision is
     # guaranteed to be 1, as returned by precision_recall_curve
-    return -torch.sum(recall[1:] - recall[:-1] * precision[:-1])
+    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
 
 
 def dice_score(
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index 4b748f5444..6a06ef4a4a 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -342,18 +342,19 @@ def test_auc(x, y, expected):
     assert auc(torch.tensor(x), torch.tensor(y)) == expected
 
 
-def test_average_precision_constant_values():
+@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
     # Check the average_precision_score of a constant predictor is
     # the TPR
 
-    # Generate a dataset with 25% of positives
-    target = torch.zeros(100, dtype=torch.float)
-    target[::4] = 1
     # And a constant score
-    pred = torch.ones(100)
     # The precision is then the fraction of positive whatever the recall
     # is, as there is only one threshold:
-    assert average_precision(pred, target).item() == .25
+    pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
+    # With threshold .8: 1 TP, 2 TN and 1 FN
+    pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
+])
+def test_average_precision(scores, target, expected_score):
+    assert average_precision(scores, target) == expected_score
 
 
 @pytest.mark.parametrize(['pred', 'target', 'expected'], [
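
To see why the missing parentheses matter, here is a minimal sketch (not taken from the patch) using hand-picked precision/recall values shaped like the output of precision_recall_curve, with the last precision entry equal to 1. Without the parentheses, operator precedence multiplies recall[:-1] by precision[:-1] before the subtraction, so the sum is no longer the step-function integral of the precision-recall curve.

import torch

# Hand-picked values shaped like the output of precision_recall_curve
# (recall decreasing to 0, last precision entry equal to 1); these are
# illustrative only, not taken from the patch.
precision = torch.tensor([0.5, 1.0, 1.0])
recall = torch.tensor([1.0, 0.5, 0.0])

# Old expression: the multiplication binds before the subtraction, so the
# recall deltas are never formed.
buggy = -torch.sum(recall[1:] - recall[:-1] * precision[:-1])

# Patched expression: each recall step is weighted by the precision at that
# step, i.e. the step-function integral of the precision-recall curve.
fixed = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

print(buggy.item(), fixed.item())  # 0.5 0.75 -- only the second is the AP

Note that the constant-predictor case removed from the old test has a single threshold, so both expressions happen to reduce to the same value (0.25) and the bug went undetected; the second parametrized case added here produces different values under the two expressions, which is why it failed with the old implementation.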