from functools import partial
from typing import Callable, Optional

import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score

from pytorch_lightning.metrics import Metric, Precision, Recall
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import precision, precision_recall, recall
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD

torch.manual_seed(42)


def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
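    """Compute the reference score with sklearn (``sk_fn`` is ``precision_score`` or
    ``recall_score``), after converting the inputs to the common binarized format.

    The unused ``mdmc_average`` argument only keeps the signature compatible with
    ``_sk_prec_recall_multidim_multiclass``, so both wrappers can be called the same way.
    """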
    if average == "none":
        average = None
    if num_classes == 1:
        average = "binary"

    labels = list(range(num_classes))
    try:
        labels.remove(ignore_index)
    except ValueError:
        pass

    sk_preds, sk_target, _ = _input_format_classification(
        preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )
    sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

    sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels)
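
    # the ignored class was dropped from ``labels`` above, so for per-class results
    # re-insert a NaN placeholder at its position to keep the class indexing aligned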
    if len(labels) != num_classes and not average:
        sk_scores = np.insert(sk_scores, ignore_index, np.nan)

    return sk_scores


def _sk_prec_recall_multidim_multiclass(
    preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
):
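    """Compute the sklearn reference for multidim-multiclass inputs.

    ``mdmc_average`` controls how the extra sample dimension is handled: "global"
    folds it into the batch dimension and scores once, while "samplewise" scores
    each sample separately and averages the per-sample results.
    """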
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )
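
    # "global": fold the extra dimension into the sample dimension and score once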
    if mdmc_average == "global":
        preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

        return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index)
    elif mdmc_average == "samplewise":
        scores = []
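
        # "samplewise": score each sample on its own, then average the per-sample scores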
        for i in range(preds.shape[0]):
            pred_i = preds[i, ...].T
            target_i = target[i, ...].T
            scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index)

            scores.append(np.expand_dims(scores_i, 0))

        return np.concatenate(scores).mean(axis=0)


@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)])
@pytest.mark.parametrize(
    "average, mdmc_average, num_classes, ignore_index, match_str",
    [
        ("wrong", None, None, None, "`average`"),
        ("micro", "wrong", None, None, "`mdmc"),
        ("macro", None, None, None, "number of classes"),
        ("macro", None, 1, 0, "ignore_index"),
    ],
)
def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str):
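    """Check that invalid argument combinations raise a ValueError with a matching
    message for the class, the functional, and the joint precision_recall interface."""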
    with pytest.raises(ValueError, match=match_str):
        metric(
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        fn_metric(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        precision_recall(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )


@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_zero_division(metric_class, metric_fn):
    """Test that zero divisions are handled correctly (the affected scores are currently always set to 0)."""
    preds = torch.tensor([1, 2, 1, 1])
    target = torch.tensor([2, 1, 2, 1])
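
    # class 0 appears in neither preds nor target, so its precision and recall
    # are both 0/0 and should be reported as 0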
    cl_metric = metric_class(average="none", num_classes=3)
    cl_metric(preds, target)

    result_cl = cl_metric.compute()
    result_fn = metric_fn(preds, target, average="none", num_classes=3)

    assert result_cl[0] == result_fn[0] == 0


@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_no_support(metric_class, metric_fn):
    """Test a rare edge case where only a single class is present in the target,
    ignore_index is set to exactly that class, and the average method is 'weighted'.

    The sum of the class weights is then zero and, without special handling, the
    result would be NaN. The reduction function should catch this and set the metric
    to the value of zero_division instead (zero_division is for now not configurable
    and equals 0).
    """
    preds = torch.tensor([1, 1, 0, 0])
    target = torch.tensor([0, 0, 0, 0])
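
    # every target entry belongs to the ignored class 0, so no support is left
    # after ignoring it and the weighted average has nothing to weight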
    cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0)
    cl_metric(preds, target)

    result_cl = cl_metric.compute()
    result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0)

    assert result_cl == result_fn == 0


@pytest.mark.parametrize(
    "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)]
)
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize(
    "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
    [
        (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
        (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
        (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall),
        (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global",
            _sk_prec_recall_multidim_multiclass
        ),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise",
            _sk_prec_recall_multidim_multiclass
        ),
    ],
)
class TestPrecisionRecall(MetricTester):
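    """Compare the Precision/Recall class and functional implementations against the
    sklearn reference (via ``sk_wrapper``) for all parametrized input types."""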

    @pytest.mark.parametrize("ddp", [False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_precision_recall_class(
        self,
        ddp: bool,
        dist_sync_on_step: bool,
        preds: torch.Tensor,
        target: torch.Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        is_multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Skip the special case where an entire sample is ignored for 'weighted' average")

        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=metric_class,
            sk_metric=partial(
                sk_wrapper,
                sk_fn=sk_fn,
                average=average,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                ignore_index=ignore_index,
                mdmc_average=mdmc_average,
            ),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "is_multiclass": is_multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
            check_dist_sync_on_step=True,
            check_batch=True,
        )

    def test_precision_recall_fn(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        is_multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Skip the special case where an entire sample is ignored for 'weighted' average")

        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=metric_fn,
            sk_metric=partial(
                sk_wrapper,
                sk_fn=sk_fn,
                average=average,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                ignore_index=ignore_index,
                mdmc_average=mdmc_average,
            ),
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "is_multiclass": is_multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
        )


@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
def test_precision_recall_joint(average):
    """A simple test of the joint precision_recall metric.

    No need to test this exhaustively, as it is just a combination of precision and
    recall, which are each already tested thoroughly.
    """
    precision_result = precision(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )
    recall_result = recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    prec_recall_result = precision_recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    assert torch.equal(precision_result, prec_recall_result[0])
    assert torch.equal(recall_result, prec_recall_result[1])


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
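

# Expected values for the top_k cases below were derived by hand with micro counting:
# e.g. for the multiclass inputs with k=2, each of the 3 target labels appears in its
# sample's top-2 predictions, giving TP=3 out of 3 * 2 = 6 predicted positives
# (precision 1/2) and out of 3 actual positives (recall 1.0).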
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
@pytest.mark.parametrize(
    "k, preds, target, average, expected_prec, expected_recall",
    [
        (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
        (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)),
        (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)),
        (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)),
    ],
)
def test_top_k(
    metric_class,
    metric_fn,
    k: int,
    preds: torch.Tensor,
    target: torch.Tensor,
    average: str,
    expected_prec: torch.Tensor,
    expected_recall: torch.Tensor,
):
    """A simple test to check that top_k works as expected.

    Just a sanity check, as the StatScores tests should already guarantee the
    correctness of the results.
    """
    class_metric = metric_class(top_k=k, average=average, num_classes=3)
    class_metric.update(preds, target)

    if metric_class.__name__ == "Precision":
        result = expected_prec
    else:
        result = expected_recall

    assert torch.equal(class_metric.compute(), result)
    assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)