diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 8af2a7ca1e..43dda00fc0 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -1024,11 +1024,11 @@ def iou( Example: - >>> target = torch.randint(0, 1, (10, 25, 25)) + >>> target = torch.randint(0, 2, (10, 25, 25)) >>> pred = torch.tensor(target) >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] >>> iou(pred, target) - tensor(0.4914) + tensor(0.9660) """ num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)