Fix average_precision metric (#2319)
* Fixed average_precision metric, parenthesis were missing. Added test test that failed with the old implementation * Modified CHANGELOG.md * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
parent
63bd0582e3
commit
92f122e0df
|
@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
|
||||
|
||||
- Fixed `average_precision` metric ([#2319](https://github.com/PyTorchLightning/pytorch-lightning/pull/2319))
|
||||
|
||||
|
||||
## [0.8.1] - 2020-06-19
|
||||
|
||||
|
|
|
@ -844,7 +844,7 @@ def average_precision(
|
|||
# Return the step function integral
|
||||
# The following works because the last entry of precision is
|
||||
# guaranteed to be 1, as returned by precision_recall_curve
|
||||
return -torch.sum(recall[1:] - recall[:-1] * precision[:-1])
|
||||
return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
|
||||
|
||||
|
||||
def dice_score(
|
||||
|
|
|
@ -342,18 +342,19 @@ def test_auc(x, y, expected):
|
|||
assert auc(torch.tensor(x), torch.tensor(y)) == expected
|
||||
|
||||
|
||||
def test_average_precision_constant_values():
|
||||
@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
|
||||
# Check the average_precision_score of a constant predictor is
|
||||
# the TPR
|
||||
|
||||
# Generate a dataset with 25% of positives
|
||||
target = torch.zeros(100, dtype=torch.float)
|
||||
target[::4] = 1
|
||||
# And a constant score
|
||||
pred = torch.ones(100)
|
||||
# The precision is then the fraction of positive whatever the recall
|
||||
# is, as there is only one threshold:
|
||||
assert average_precision(pred, target).item() == .25
|
||||
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
|
||||
# With treshold .8 : 1 TP and 2 TN and one FN
|
||||
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
|
||||
])
|
||||
def test_average_precision(scores, target, expected_score):
|
||||
assert average_precision(scores, target) == expected_score
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
|
|
Loading…
Reference in New Issue