Feature/4244 iou input expectations (#4261)

* =Add iou input checks

* =Add test for iou input checks

* =Update docstring for iou pred and target
This commit is contained in:
Dusan Drevicky 2020-10-21 15:01:24 +02:00 committed by GitHub
parent d27ee8b5bf
commit 96785b99df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 2 deletions

View File

@ -1002,8 +1002,8 @@ def iou(
Intersection over union, or Jaccard index calculation.
Args:
pred: Tensor containing predictions
target: Tensor containing targets
pred: Tensor containing integer predictions, with shape [N, d1, d2, ...]
target: Tensor containing integer targets, with shape [N, d1, d2, ...]
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no
@ -1032,6 +1032,12 @@ def iou(
tensor(0.9660)
"""
if pred.size() != target.size():
raise ValueError(f"'pred' shape ({pred.size()}) must equal 'target' shape ({target.size()})")
if not torch.allclose(pred.float(), pred.int().float()):
raise ValueError("'pred' must contain integer targets.")
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)

View File

@ -406,6 +406,16 @@ def test_iou(half_ones, reduction, ignore_index, expected):
assert torch.allclose(iou_val, expected, atol=1e-9)
def test_iou_input_check():
with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"):
_ = iou(pred=torch.randint(0, 2, (3, 4, 3)),
target=torch.randint(0, 2, (3, 3)))
with pytest.raises(ValueError, match="'pred' must contain integer targets."):
_ = iou(pred=torch.rand((3, 3)),
target=torch.randint(0, 2, (3, 3)))
@pytest.mark.parametrize('metric', [auroc])
def test_error_on_multiclass_input(metric):
""" check that these metrics raise an error if they are used for multiclass problems """