# lightning/tests/metrics/functional/test_classification.py

from functools import partial
import pytest
import torch
from sklearn.metrics import (
accuracy_score as sk_accuracy,
jaccard_score as sk_jaccard_score,
precision_score as sk_precision,
recall_score as sk_recall,
f1_score as sk_f1_score,
fbeta_score as sk_fbeta_score,
confusion_matrix as sk_confusion_matrix,
roc_curve as sk_roc_curve,
roc_auc_score as sk_roc_auc_score,
precision_recall_curve as sk_precision_recall_curve
)
from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional.classification import (
to_onehot,
to_categorical,
get_num_classes,
stat_scores,
stat_scores_multiple_classes,
accuracy,
confusion_matrix,
precision,
recall,
fbeta_score,
f1_score,
_binary_clf_curve,
dice_score,
average_precision,
auroc,
precision_recall_curve,
roc,
auc,
iou,
)
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [
pytest.param(sk_accuracy, accuracy, False, id='accuracy'),
pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'),
pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'),
pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'),
pytest.param(partial(sk_f1_score, average='micro'), f1_score, False, id='f1_score'),
pytest.param(partial(sk_fbeta_score, average='micro', beta=2),
partial(fbeta_score, beta=2), False, id='fbeta_score'),
pytest.param(sk_confusion_matrix, confusion_matrix, False, id='confusion_matrix'),
pytest.param(sk_roc_curve, roc, True, id='roc'),
pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'),
pytest.param(sk_roc_auc_score, auroc, True, id='auroc')
])
def test_against_sklearn(sklearn_metric, torch_metric, only_binary):
"""Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# for metrics with only_binary=False, we try out different combinations of number
# of labels in pred and target (also test binary)
# for metrics with only_binary=True, target is always binary and pred will be
# (unnormalized) class probabilities
class_comb = [(5, 2)] if only_binary else [(10, 10), (5, 10), (10, 5), (2, 2)]
for n_cls_pred, n_cls_target in class_comb:
pred = torch.randint(n_cls_pred, (300,), device=device)
target = torch.randint(n_cls_target, (300,), device=device)
sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy())
pl_score = torch_metric(pred, target)
# curve-style metrics (roc, precision_recall_curve) return a tuple of outputs
if isinstance(sk_score, tuple):
sk_score = [torch.tensor(sk_s.copy(), dtype=torch.float, device=device) for sk_s in sk_score]
for sk_s, pl_s in zip(sk_score, pl_score):
assert torch.allclose(sk_s, pl_s.float())
else:
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
assert torch.allclose(sk_score, pl_score)
@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted'])
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(sk_precision, precision, id='precision'),
pytest.param(sk_recall, recall, id='recall'),
pytest.param(sk_f1_score, f1_score, id='f1_score'),
pytest.param(partial(sk_fbeta_score, beta=2), partial(fbeta_score, beta=2), id='fbeta_score')
])
def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric):
""" Test metrics where the class_reduction parameter have a correponding
value in sklearn """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pred = torch.randint(10, (300,), device=device)
target = torch.randint(10, (300,), device=device)
sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
average=class_reduction)
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target, class_reduction=class_reduction)
assert torch.allclose(sk_score, pl_score)
def test_onehot():
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
expected = torch.stack([
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
])
assert test_tensor.shape == (2, 5)
assert expected.shape == (2, 10, 5)
onehot_classes = to_onehot(test_tensor, num_classes=10)
onehot_no_classes = to_onehot(test_tensor)
assert torch.allclose(onehot_classes, onehot_no_classes)
assert onehot_classes.shape == expected.shape
assert onehot_no_classes.shape == expected.shape
assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
assert torch.allclose(expected.to(onehot_classes), onehot_classes)
def test_to_categorical():
test_tensor = torch.stack([
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
]).to(torch.float)
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
assert expected.shape == (2, 5)
assert test_tensor.shape == (2, 10, 5)
result = to_categorical(test_tensor)
assert result.shape == expected.shape
assert torch.allclose(result, expected.to(result.dtype))
@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
])
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
assert get_num_classes(pred, target, num_classes) == expected_num_classes
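# Worked example for the stat_scores cases below with class_index=4:
# pred = [0, 2, 4, 4] vs target = [0, 4, 3, 4]
#   position 0: pred 0, target 0 -> TN (neither is class 4)
#   position 1: pred 2, target 4 -> FN
#   position 2: pred 4, target 3 -> FP
#   position 3: pred 4, target 4 -> TP
# support = number of target entries equal to 4 = 2; the one-hot variant encodes the same prediction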
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
'expected_tn', 'expected_fn', 'expected_support'], [
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2)
])
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)
assert tp.item() == expected_tp
assert fp.item() == expected_fp
assert tn.item() == expected_tn
assert fn.item() == expected_fn
assert sup.item() == expected_support
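# For the reduction cases below: 'sum' adds the per-class statistics over all 5 classes
# (e.g. tp = 1 + 0 + 0 + 0 + 1 = 2), while 'elementwise_mean' divides that sum by the number
# of classes (e.g. tp = 2 / 5 = 0.4)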
@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp',
'expected_tn', 'expected_fn', 'expected_support'], [
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none',
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none',
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum',
torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
])
def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
assert torch.allclose(torch.tensor(expected_support).to(sup), sup)
def test_multilabel_accuracy():
# Dense label indicator matrix format
y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
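# with class_reduction='none' accuracy is reported per class; the classes here are the binary
# label values 0 and 1: y1 matches y2 on 2 of the 3 positions where y2 == 0 and on every
# position where y2 == 1, hence the expected [2 / 3, 1.]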
assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.]))
assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.]))
assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.]))
assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.]))
assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.]))
# if num_classes does not match the number of classes extracted from the input, we expect a warning
with pytest.warns(RuntimeWarning,
match=r'You have set .* number of classes which is'
r' different from predicted (.*) and'
r' target (.*) number of classes'):
_ = accuracy(y2, torch.zeros_like(y2), num_classes=3)
def test_accuracy():
pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])
acc = accuracy(pred, target)
assert acc.item() == 0.75
pred = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 3])
acc = accuracy(pred, target)
assert acc.item() == 0.50
def test_confusion_matrix():
target = (torch.arange(120) % 3).view(-1, 1)
pred = target.clone()
cm = confusion_matrix(pred, target, normalize=True)
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))
pred = torch.zeros_like(pred)
cm = confusion_matrix(pred, target, normalize=True)
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))
target = torch.LongTensor([0, 0, 0, 0, 0])
pred = target.clone()
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]]))
# Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9)
pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15)
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]]))
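# normalize=True divides each row by its row sum, which here is the number of target samples
# per class (13, 16 and 9), which is what to_compare reproduces below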
to_compare = cm / torch.tensor([[13.], [16.], [9.]])
cm = confusion_matrix(pred, target, normalize=True, num_classes=3)
assert torch.allclose(cm, to_compare)
@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
])
def test_precision_recall(pred, target, expected_prec, expected_rec):
prec = precision(pred, target, class_reduction='none')
rec = recall(pred, target, class_reduction='none')
assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)
@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
])
def test_fbeta_score(pred, target, beta, exp_score):
score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
])
def test_f1_score(pred, target, exp_score):
score = f1_score(torch.tensor(pred), torch.tensor(target), class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
@pytest.mark.parametrize(['sample_weight', 'pos_label', 'exp_shape'], [
pytest.param(1, 1., 42),
pytest.param(None, 1., 42),
])
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
# TODO: move pred and target back to the test function arguments.
# The arrays are fixed inside the function for now; whenever they change, the expected
# shape below has to be updated to match, since the two are coupled.
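# exp_shape corresponds to the number of distinct prediction values generated under seed 0,
# since _binary_clf_curve reports one entry per unique threshold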
seed_everything(0)
pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
target = torch.tensor([0, 1] * 50, dtype=torch.int)
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight
fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
assert isinstance(tps, torch.Tensor)
assert isinstance(fps, torch.Tensor)
assert isinstance(thresh, torch.Tensor)
assert tps.shape == (exp_shape,)
assert fps.shape == (exp_shape,)
assert thresh.shape == (exp_shape,)
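# For the case below, the thresholds are the distinct predictions [1, 2, 3, 4] and a final point
# with precision 1 and recall 0 is appended (mirroring sklearn), which is why the curve is one
# entry longer than the thresholds (p.size(0) == t.size(0) + 1)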
@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
])
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
assert p.size() == r.size()
assert p.size(0) == t.size(0) + 1
assert torch.allclose(p, torch.tensor(expected_p).to(p))
assert torch.allclose(r, torch.tensor(expected_r).to(r))
assert torch.allclose(t, torch.tensor(expected_t).to(t))
@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
])
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
assert fpr.shape == tpr.shape
assert fpr.size(0) == thresh.size(0)
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
])
def test_auroc(pred, target, expected):
score = auroc(torch.tensor(pred), torch.tensor(target)).item()
assert score == expected
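# The expected values below are consistent with a trapezoidal integration of y over x,
# e.g. x = [0, 1], y = [1, 1] encloses an area of 1 and the identity line over [0, 1] gives 0.5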
@pytest.mark.parametrize(['x', 'y', 'expected'], [
pytest.param([0, 1], [0, 1], 0.5),
pytest.param([1, 0], [0, 1], 0.5),
pytest.param([1, 0, 0], [0, 1, 1], 0.5),
pytest.param([0, 1], [1, 1], 1),
pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
])
def test_auc(x, y, expected):
# Test Area Under Curve (AUC) computation
assert auc(torch.tensor(x), torch.tensor(y)) == expected
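# For the second average_precision case below: ranking by score, precision is 1.0 at recall 0.5
# (only the top score is a positive) and 0.5 at recall 1.0, so AP = 0.5 * 1.0 + 0.5 * 0.5 = 0.75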
@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
# Check that the average precision of a constant predictor equals the fraction of
# positives in the target.
# Generate a dataset with 25% positives and a constant score: the precision is then the
# fraction of positives whatever the recall is, as there is only one threshold:
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
# with a threshold above 0.8 only the score of 9 is predicted positive: 1 TP, 2 TN and 1 FN
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
])
def test_average_precision(scores, target, expected_score):
assert average_precision(scores, target) == expected_score
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),
pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
])
def test_dice_score(pred, target, expected):
score = dice_score(torch.tensor(pred), torch.tensor(target))
assert score == expected
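# In the half_ones cases below, overwriting the first 60 of the 120 predictions with class 1
# halves the intersection-over-union of every class, hence the expected per-class score of 0.5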
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
iou_val = iou(
pred=pred,
target=target,
ignore_index=ignore_index,
reduction=reduction,
)
assert torch.allclose(iou_val, expected, atol=1e-9)
@pytest.mark.parametrize('metric', [auroc])
def test_error_on_multiclass_input(metric):
""" check that these metrics raise an error if they are used for multiclass problems """
pred = torch.randint(0, 10, (100, ))
target = torch.randint(0, 10, (100, ))
with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"):
_ = metric(pred, target)
# TODO: once the sklearn version we depend on has a `jaccard_score` that accepts `zero_division` (see
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
# `absent_score`.
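# absent_score is returned only for classes that occur in neither pred nor target; a class that
# occurs in just one of them gets its regular IoU (possibly 0) instead, as the cases below show.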
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is absent.
pytest.param([0], [0], None, -1., 2, [1., -1.]),
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
# absent_score is not applied when class 0 is present and it is the only class (num_classes=1).
pytest.param([0], [0], None, -1., 1, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is absent.
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
# When index 0 is ignored, class 0 does not get a score (not even the absent_score).
pytest.param([1], [1], 0, -1., 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
# 2 is absent.
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
# 2 is absent.
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
# Sanity checks with absent_score of 1.0.
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
])
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
absent_score=absent_score,
num_classes=num_classes,
reduction='none',
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
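# With class 0 ignored, the remaining per-class IoUs are 1/2 and 2/3, so 'elementwise_mean'
# yields 7 / 12 and 'sum' yields 7 / 6.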
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
])
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))