Feature/4244 iou input expectations (#4261)
* =Add iou input checks * =Add test for iou input checks * =Update docstring for iou pred and target
This commit is contained in:
parent
d27ee8b5bf
commit
96785b99df
|
@ -1002,8 +1002,8 @@ def iou(
|
|||
Intersection over union, or Jaccard index calculation.
|
||||
|
||||
Args:
|
||||
pred: Tensor containing predictions
|
||||
target: Tensor containing targets
|
||||
pred: Tensor containing integer predictions, with shape [N, d1, d2, ...]
|
||||
target: Tensor containing integer targets, with shape [N, d1, d2, ...]
|
||||
ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
|
||||
to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
|
||||
range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no
|
||||
|
@ -1032,6 +1032,12 @@ def iou(
|
|||
tensor(0.9660)
|
||||
|
||||
"""
|
||||
if pred.size() != target.size():
|
||||
raise ValueError(f"'pred' shape ({pred.size()}) must equal 'target' shape ({target.size()})")
|
||||
|
||||
if not torch.allclose(pred.float(), pred.int().float()):
|
||||
raise ValueError("'pred' must contain integer targets.")
|
||||
|
||||
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
|
||||
|
||||
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
|
||||
|
|
|
@ -406,6 +406,16 @@ def test_iou(half_ones, reduction, ignore_index, expected):
|
|||
assert torch.allclose(iou_val, expected, atol=1e-9)
|
||||
|
||||
|
||||
def test_iou_input_check():
|
||||
with pytest.raises(ValueError, match=r"'pred' shape (.*) must equal 'target' shape (.*)"):
|
||||
_ = iou(pred=torch.randint(0, 2, (3, 4, 3)),
|
||||
target=torch.randint(0, 2, (3, 3)))
|
||||
|
||||
with pytest.raises(ValueError, match="'pred' must contain integer targets."):
|
||||
_ = iou(pred=torch.rand((3, 3)),
|
||||
target=torch.randint(0, 2, (3, 3)))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('metric', [auroc])
|
||||
def test_error_on_multiclass_input(metric):
|
||||
""" check that these metrics raise an error if they are used for multiclass problems """
|
||||
|
|
Loading…
Reference in New Issue