diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index e87bfb7a00..abf661e46f 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -1002,8 +1002,8 @@ def iou( Intersection over union, or Jaccard index calculation. Args: - pred: Tensor containing predictions - target: Tensor containing targets + pred: Tensor containing integer predictions, with shape [N, d1, d2, ...] + target: Tensor containing integer targets, with shape [N, d1, d2, ...] ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute to the returned score, regardless of reduction method. Has no effect if given an int that is not in the range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no @@ -1032,6 +1032,12 @@ def iou( tensor(0.9660) """ + if pred.size() != target.size(): + raise ValueError(f"'pred' shape ({pred.size()}) must equal 'target' shape ({target.size()})") + + if not torch.allclose(pred.float(), pred.int().float()): + raise ValueError("'pred' must contain integer targets.") + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 9afdf84fa8..139aeea8cc 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -406,6 +406,16 @@ def test_iou(half_ones, reduction, ignore_index, expected): assert torch.allclose(iou_val, expected, atol=1e-9) +def test_iou_input_check(): + with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"): + _ = iou(pred=torch.randint(0, 2, (3, 4, 3)), + target=torch.randint(0, 2, (3, 3))) + + with pytest.raises(ValueError, match="'pred' must contain integer targets."): + _ = iou(pred=torch.rand((3, 3)), + target=torch.randint(0, 2, (3, 3))) + + @pytest.mark.parametrize('metric', [auroc]) def test_error_on_multiclass_input(metric): """ check that these metrics raise an error if they are used for multiclass problems """