# lightning/tests/metrics/classification/test_precision_recall.py
from functools import partial
from typing import Callable, Optional

import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score

from pytorch_lightning.metrics import Metric, Precision, Recall
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import precision, precision_recall, recall
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD

torch.manual_seed(42)


def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
    # Reference implementation: compute the score with sklearn on inputs pre-processed
    # by the same `_input_format_classification` helper that the metric itself uses.
    if average == "none":
        average = None
    if num_classes == 1:
        average = "binary"

    labels = list(range(num_classes))
    try:
        labels.remove(ignore_index)
    except ValueError:
        pass

    sk_preds, sk_target, _ = _input_format_classification(
        preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )
    sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

    sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels)

    if len(labels) != num_classes and not average:
        # Re-insert a placeholder for the ignored class so the per-class output
        # has the same shape as the metric's output.
        sk_scores = np.insert(sk_scores, ignore_index, np.nan)

    return sk_scores


def _sk_prec_recall_multidim_multiclass(
    preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
):
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )

    if mdmc_average == "global":
        # Fold the extra dimension into the sample dimension and score everything at once.
        preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

        return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index)
    elif mdmc_average == "samplewise":
        # Score each sample separately, then average the per-sample scores.
        scores = []

        for i in range(preds.shape[0]):
            pred_i = preds[i, ...].T
            target_i = target[i, ...].T
            scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index)

            scores.append(np.expand_dims(scores_i, 0))

        return np.concatenate(scores).mean(axis=0)
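

# Illustration of the two mdmc (multi-dimensional multi-class) reduction modes used above,
# assuming inputs of shape (N, C, X) after `_input_format_classification`: "global" folds the
# N and X dimensions into N * X flat samples and computes a single score, while "samplewise"
# computes one score per sample over its X positions and returns the mean of the N scores.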
@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)])
@pytest.mark.parametrize(
"average, mdmc_average, num_classes, ignore_index, match_str",
[
("wrong", None, None, None, "`average`"),
("micro", "wrong", None, None, "`mdmc"),
("macro", None, None, None, "number of classes"),
("macro", None, 1, 0, "ignore_index"),
],
)
def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str):
with pytest.raises(ValueError, match=match_str):
metric(
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
ignore_index=ignore_index,
)
with pytest.raises(ValueError, match=match_str):
fn_metric(
_input_binary.preds[0],
_input_binary.target[0],
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
ignore_index=ignore_index,
)
with pytest.raises(ValueError, match=match_str):
precision_recall(
_input_binary.preds[0],
_input_binary.target[0],
average=average,
mdmc_average=mdmc_average,
num_classes=num_classes,
ignore_index=ignore_index,
)
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_zero_division(metric_class, metric_fn):
""" Test that zero_division works correctly (currently should just set to 0). """
preds = torch.tensor([1, 2, 1, 1])
target = torch.tensor([2, 1, 2, 1])
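    # With these inputs, class 0 is never predicted and never appears in the target, so its
    # TP, FP and FN are all zero and precision/recall would be 0/0. The metric is expected
    # to report 0 for that class instead of NaN (zero_division is fixed at 0).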
cl_metric = metric_class(average="none", num_classes=3)
cl_metric(preds, target)
result_cl = cl_metric.compute()
result_fn = metric_fn(preds, target, average="none", num_classes=3)
assert result_cl[0] == result_fn[0] == 0
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_no_support(metric_class, metric_fn):
"""This tests a rare edge case, where there is only one class present
in target, and ignore_index is set to exactly that class - and the
average method is equal to 'weighted'.
This would mean that the sum of weights equals zero, and would, without
taking care of this case, return NaN. However, the reduction function
should catch that and set the metric to equal the value of zero_division
in this case (zero_division is for now not configurable and equals 0).
"""
preds = torch.tensor([1, 1, 0, 0])
target = torch.tensor([0, 0, 0, 0])
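    # Worked through: the target contains only class 0, which is exactly the ignored class,
    # so every sample is dropped and class 1 has zero support. The weighted average would
    # therefore divide by a zero weight sum, and the metric should fall back to the
    # zero_division value (0) rather than returning NaN.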
cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0)
cl_metric(preds, target)
result_cl = cl_metric.compute()
result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0)
assert result_cl == result_fn == 0


@pytest.mark.parametrize(
    "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)]
)
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize(
    "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
    [
        (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
        (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
        (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall),
        (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global",
            _sk_prec_recall_multidim_multiclass
        ),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise",
            _sk_prec_recall_multidim_multiclass
        ),
    ],
)
class TestPrecisionRecall(MetricTester):
    @pytest.mark.parametrize("ddp", [False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_precision_recall_class(
        self,
        ddp: bool,
        dist_sync_on_step: bool,
        preds: torch.Tensor,
        target: torch.Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        is_multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")

        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=metric_class,
            sk_metric=partial(
                sk_wrapper,
                sk_fn=sk_fn,
                average=average,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                ignore_index=ignore_index,
                mdmc_average=mdmc_average,
            ),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "is_multiclass": is_multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
            check_dist_sync_on_step=True,
            check_batch=True,
        )

    def test_precision_recall_fn(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        is_multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")

        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=metric_fn,
            sk_metric=partial(
                sk_wrapper,
                sk_fn=sk_fn,
                average=average,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                ignore_index=ignore_index,
                mdmc_average=mdmc_average,
            ),
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "is_multiclass": is_multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
        )
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
def test_precision_recall_joint(average):
"""A simple test of the joint precision_recall metric.
No need to test this thorougly, as it is just a combination of precision and recall,
which are already tested thoroughly.
"""
precision_result = precision(
_input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
)
recall_result = recall(
_input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
)
prec_recall_result = precision_recall(
_input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
)
assert torch.equal(precision_result, prec_recall_result[0])
assert torch.equal(recall_result, prec_recall_result[1])


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
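
# Worked example for the multiclass cases below: with k=2 the top-2 predictions per row of
# `_mc_k_preds` are {0, 1}, {1, 2} and {0, 2}, so every target in `_mc_k_target` (0, 1, 2)
# is hit (micro recall = 3/3 = 1.0) while 3 of the 6 predicted labels are wrong
# (micro precision = 3/6 = 0.5). With k=1 the argmax predictions are (1, 1, 2), giving 2 hits
# out of 3 for both micro precision and micro recall (2/3).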
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
@pytest.mark.parametrize(
"k, preds, target, average, expected_prec, expected_recall",
[
(1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
(2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)),
(1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)),
(2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)),
],
)
def test_top_k(
metric_class,
metric_fn,
k: int,
preds: torch.Tensor,
target: torch.Tensor,
average: str,
expected_prec: torch.Tensor,
expected_recall: torch.Tensor,
):
"""A simple test to check that top_k works as expected.
Just a sanity check, the tests in StatScores should already guarantee
the corectness of results.
"""
class_metric = metric_class(top_k=k, average=average, num_classes=3)
class_metric.update(preds, target)
if metric_class.__name__ == "Precision":
result = expected_prec
else:
result = expected_recall
assert torch.equal(class_metric.compute(), result)
assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)