# NOTE: This file only tests if modules with arguments are running fine.
# The actual metric implementation is tested in functional/test_classification.py
# Especially reduction and reducing across processes won't be tested here!

import pytest
import torch

from pytorch_lightning.metrics.classification import (
    Accuracy,
    ConfusionMatrix,
    PrecisionRecallCurve,
    Precision,
    Recall,
    AveragePrecision,
    AUROC,
    FBeta,
    F1,
    ROC,
    MulticlassROC,
    MulticlassPrecisionRecallCurve,
    DiceCoefficient,
    IoU,
)


@pytest.fixture(autouse=True)
def random():
    """Seed torch's global RNG before every test in this module.

    Declared ``autouse`` so that tests drawing inputs from ``torch.randint``
    get the same data on every run without having to request the fixture.
    """
    torch.manual_seed(0)


@pytest.mark.parametrize('num_classes', [1, None])
def test_accuracy(num_classes):
    """Check that ``Accuracy`` reports its name and returns a tensor when called."""
    acc = Accuracy(num_classes=num_classes)

    assert acc.name == 'accuracy'

    result = acc(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
                 target=torch.tensor([[0, 0, 1], [1, 0, 1]]))

    assert isinstance(result, torch.Tensor)


@pytest.mark.parametrize(['normalize', 'num_classes'], [
    pytest.param(False, None),
    pytest.param(True, None),
    pytest.param(False, 3)
])
def test_confusion_matrix(normalize, num_classes):
    """Check that ``ConfusionMatrix`` reports its name and returns a tensor."""
    conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
    assert conf_matrix.name == 'confusion_matrix'

    # 120 labels cycling through 0, 1, 2; predictions are an exact copy.
    target = (torch.arange(120) % 3).view(-1, 1)
    pred = target.clone()

    cm = conf_matrix(pred, target)

    assert isinstance(cm, torch.Tensor)


@pytest.mark.parametrize(['normalize', 'num_classes'], [
    pytest.param(True, 3)
])
def test_confusion_matrix_norm(normalize, num_classes):
    """Check that the user is warned when nans in the confusion matrix are replaced by zeros."""
    conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
    assert conf_matrix.name == 'confusion_matrix'

    # Only class 0 occurs although three classes are requested.
    target = torch.LongTensor([0] * 5)
    pred = target.clone()

    # Keep the ``warns`` context limited to the call that is expected to warn.
    with pytest.warns(UserWarning, match='6 nan values found in confusion matrix have been replaced with zeros.'):
        cm = conf_matrix(pred, target)

    assert isinstance(cm, torch.Tensor)


@pytest.mark.parametrize('pos_label', [1, 2.])
def test_precision_recall(pos_label):
    """Check that ``PrecisionRecallCurve`` returns a 3-tuple of tensors."""
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])

    pr_curve = PrecisionRecallCurve(pos_label=pos_label)
    assert pr_curve.name == 'precision_recall_curve'

    pr = pr_curve(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(pr, tuple)
    assert len(pr) == 3
    for tmp in pr:
        assert isinstance(tmp, torch.Tensor)


@pytest.mark.parametrize('num_classes', [1, None])
def test_precision(num_classes):
    """Check that ``Precision`` reports its name and returns a tensor when called."""
    precision = Precision(num_classes=num_classes)

    assert precision.name == 'precision'
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])

    prec = precision(pred=pred, target=target)

    assert isinstance(prec, torch.Tensor)


@pytest.mark.parametrize('num_classes', [1, None])
def test_recall(num_classes):
    """Check that ``Recall`` reports its name and returns a tensor when called."""
    recall = Recall(num_classes=num_classes)

    assert recall.name == 'recall'
    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])

    rec = recall(pred=pred, target=target)

    assert isinstance(rec, torch.Tensor)


@pytest.mark.parametrize('pos_label', [1, 2])
def test_average_precision(pos_label):
    """Check that ``AveragePrecision`` reports its name and returns a tensor."""
    avg_prec = AveragePrecision(pos_label=pos_label)

    assert avg_prec.name == 'AP'

    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])

    ap = avg_prec(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(ap, torch.Tensor)


@pytest.mark.parametrize('pos_label', [0, 1])
def test_auroc(pos_label):
    """Check that ``AUROC`` reports its name and returns a tensor when called."""
    auroc = AUROC(pos_label=pos_label)

    assert auroc.name == 'auroc'

    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 1, 0, 1])

    area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(area, torch.Tensor)


@pytest.mark.parametrize(['beta', 'num_classes'], [
    pytest.param(0., 1),
    pytest.param(0.5, 1),
    pytest.param(1., 1),
    pytest.param(2., 1),
    pytest.param(0., None),
    pytest.param(0.5, None),
    pytest.param(1., None),
    pytest.param(2., None)
])
def test_fbeta(beta, num_classes):
    """Check that ``FBeta`` reports its name and returns a tensor when called."""
    fbeta = FBeta(beta=beta, num_classes=num_classes)

    assert fbeta.name == 'fbeta'

    score = fbeta(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
                  target=torch.tensor([[0, 0, 1], [1, 0, 1]]))

    assert isinstance(score, torch.Tensor)


@pytest.mark.parametrize('num_classes', [1, None])
def test_f1(num_classes):
    """Check that ``F1`` reports its name and returns a tensor when called."""
    f1 = F1(num_classes=num_classes)

    assert f1.name == 'f1'

    score = f1(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
               target=torch.tensor([[0, 0, 1], [1, 0, 1]]))

    assert isinstance(score, torch.Tensor)


@pytest.mark.parametrize('pos_label', [1, 2])
def test_roc(pos_label):
    """Check that ``ROC`` returns a 3-tuple of tensors."""
    roc = ROC(pos_label=pos_label)

    assert roc.name == 'roc'

    pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3])

    res = roc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])

    assert isinstance(res, tuple)
    assert len(res) == 3
    for tmp in res:
        assert isinstance(tmp, torch.Tensor)


@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_roc(num_classes):
    """Check that ``MulticlassROC`` returns a tuple of 3-tuples of tensors."""
    pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                         [0.05, 0.85, 0.05, 0.05],
                         [0.05, 0.05, 0.85, 0.05],
                         [0.05, 0.05, 0.05, 0.85]])
    target = torch.tensor([0, 1, 3, 2])

    multi_roc = MulticlassROC(num_classes=num_classes)

    assert multi_roc.name == 'multiclass_roc'

    res = multi_roc(pred, target)

    assert isinstance(res, tuple)

    # The number of per-class results is only checked when it was given explicitly.
    if num_classes is not None:
        assert len(res) == num_classes

    for tmp in res:
        assert isinstance(tmp, tuple)
        assert len(tmp) == 3

        for _tmp in tmp:
            assert isinstance(_tmp, torch.Tensor)


@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_pr(num_classes):
    """Check that ``MulticlassPrecisionRecallCurve`` returns a tuple of 3-tuples of tensors."""
    pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                         [0.05, 0.85, 0.05, 0.05],
                         [0.05, 0.05, 0.85, 0.05],
                         [0.05, 0.05, 0.05, 0.85]])
    target = torch.tensor([0, 1, 3, 2])

    multi_pr = MulticlassPrecisionRecallCurve(num_classes=num_classes)

    assert multi_pr.name == 'multiclass_precision_recall_curve'

    pr = multi_pr(pred, target)

    assert isinstance(pr, tuple)

    # The number of per-class results is only checked when it was given explicitly.
    if num_classes is not None:
        assert len(pr) == num_classes

    for tmp in pr:
        assert isinstance(tmp, tuple)
        assert len(tmp) == 3

        for _tmp in tmp:
            assert isinstance(_tmp, torch.Tensor)


@pytest.mark.parametrize('include_background', [True, False])
def test_dice_coefficient(include_background):
    """Check that ``DiceCoefficient`` reports its name and returns a tensor."""
    dice_coeff = DiceCoefficient(include_background=include_background)

    assert dice_coeff.name == 'dice'

    # ``high`` is exclusive in ``torch.randint``: use 2 so that both labels
    # 0 and 1 occur (``high=1`` would always produce all-zero tensors).
    dice = dice_coeff(torch.randint(0, 2, (10, 25, 25)),
                      torch.randint(0, 2, (10, 25, 25)))

    assert isinstance(dice, torch.Tensor)


@pytest.mark.parametrize('ignore_index', [0, 1, None])
def test_iou(ignore_index):
    """Check that ``IoU`` reports its name and returns a tensor when called."""
    iou = IoU(ignore_index=ignore_index)
    assert iou.name == 'iou'

    # ``high`` is exclusive in ``torch.randint``: use 2 so that both labels
    # 0 and 1 occur (``high=1`` would always produce all-zero tensors).
    score = iou(torch.randint(0, 2, (10, 25, 25)),
                torch.randint(0, 2, (10, 25, 25)))

    assert isinstance(score, torch.Tensor)