From 96785b99df878ebc45bc455eb7902a6dd4661139 Mon Sep 17 00:00:00 2001 From: Dusan Drevicky <55678224+ddrevicky@users.noreply.github.com> Date: Wed, 21 Oct 2020 15:01:24 +0200 Subject: [PATCH] Feature/4244 iou input expectations (#4261) * =Add iou input checks * =Add test for iou input checks * =Update docstring for iou pred and target --- pytorch_lightning/metrics/functional/classification.py | 10 ++++++++-- tests/metrics/functional/test_classification.py | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index e87bfb7a00..abf661e46f 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -1002,8 +1002,8 @@ def iou( Intersection over union, or Jaccard index calculation. Args: - pred: Tensor containing predictions - target: Tensor containing targets + pred: Tensor containing integer predictions, with shape [N, d1, d2, ...] + target: Tensor containing integer targets, with shape [N, d1, d2, ...] ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. Has no effect if given an int that is not in the range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no @@ -1032,6 +1032,12 @@ def iou( tensor(0.9660) """ + if pred.size() != target.size(): + raise ValueError(f"'pred' shape ({pred.size()}) must equal 'target' shape ({target.size()})") + + if not torch.allclose(pred.float(), pred.int().float()): + raise ValueError("'pred' must contain integer targets.") + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 9afdf84fa8..139aeea8cc 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -406,6 +406,16 @@ def test_iou(half_ones, reduction, ignore_index, expected): assert torch.allclose(iou_val, expected, atol=1e-9) +def test_iou_input_check(): + with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"): + _ = iou(pred=torch.randint(0, 2, (3, 4, 3)), + target=torch.randint(0, 2, (3, 3))) + + with pytest.raises(ValueError, match="'pred' must contain integer targets."): + _ = iou(pred=torch.rand((3, 3)), + target=torch.randint(0, 2, (3, 3))) + + @pytest.mark.parametrize('metric', [auroc]) def test_error_on_multiclass_input(metric): """ check that these metrics raise an error if they are used for multiclass problems """