lightning/tests/metrics/classification/test_average_precision.py

98 lines
3.8 KiB
Python
Raw Normal View History

from functools import partial
import numpy as np
import pytest
import torch
2021-02-06 16:41:40 +00:00
from sklearn.metrics import average_precision_score as sk_average_precision_score
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision
from pytorch_lightning.metrics.functional.average_precision import average_precision
2021-02-06 16:41:40 +00:00
from tests.metrics.classification.inputs import _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES
torch.manual_seed(42)
2021-02-06 16:41:40 +00:00
def _sk_average_precision_score(y_true, probas_pred, num_classes=1):
if num_classes == 1:
2021-02-06 16:41:40 +00:00
return sk_average_precision_score(y_true, probas_pred)
res = []
for i in range(num_classes):
y_true_temp = np.zeros_like(y_true)
y_true_temp[y_true == i] = 1
2021-02-06 16:41:40 +00:00
res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i]))
return res
2021-02-06 16:41:40 +00:00
def _sk_avg_prec_binary_prob(preds, target, num_classes=1):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
2021-02-06 16:41:40 +00:00
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
2021-02-06 16:41:40 +00:00
def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.reshape(-1, num_classes).numpy()
sk_target = target.view(-1).numpy()
2021-02-06 16:41:40 +00:00
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
2021-02-06 16:41:40 +00:00
def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1):
sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy()
sk_target = target.view(-1).numpy()
2021-02-06 16:41:40 +00:00
return _sk_average_precision_score(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes)
@pytest.mark.parametrize(
"preds, target, sk_metric, num_classes", [
(_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1),
(_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES),
(_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES),
]
)
class TestAveragePrecision(MetricTester):
2021-02-06 16:41:40 +00:00
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_average_precision(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=AveragePrecision,
sk_metric=partial(sk_metric, num_classes=num_classes),
dist_sync_on_step=dist_sync_on_step,
metric_args={"num_classes": num_classes}
)
def test_average_precision_functional(self, preds, target, sk_metric, num_classes):
self.run_functional_metric_test(
preds,
target,
metric_functional=average_precision,
sk_metric=partial(sk_metric, num_classes=num_classes),
metric_args={"num_classes": num_classes},
)
2021-02-06 16:41:40 +00:00
@pytest.mark.parametrize(
['scores', 'target', 'expected_score'],
[
# Check the average_precision_score of a constant predictor is
# the TPR
# Generate a dataset with 25% of positives
# And a constant score
# The precision is then the fraction of positive whatever the recall
# is, as there is only one threshold:
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
# With threshold 0.8 : 1 TP and 2 TN and one FN
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
]
)
def test_average_precision(scores, target, expected_score):
assert average_precision(scores, target) == expected_score