fix normalize mode at confusion matrix (replace nans with zeros) (#3465)
* replace nans to 0 at conf. matrix & update tests * cm.isnan() -> torch.isnan(cm) * fix row-wise division while normalize * update tests * pep8 fix * Update tests/metrics/test_classification.py add comment to test Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Update tests/metrics/functional/test_classification.py Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Update pytorch_lightning/metrics/functional/classification.py Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * final update Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
parent
87fc43e6b5
commit
a552d4a2d5
|
@ -312,7 +312,11 @@ def confusion_matrix(
|
|||
cm = bins.reshape(num_classes, num_classes).squeeze().float()
|
||||
|
||||
if normalize:
|
||||
cm = cm / cm.sum(-1)
|
||||
cm = cm / cm.sum(-1, keepdim=True)
|
||||
nan_elements = cm[torch.isnan(cm)].nelement()
|
||||
if nan_elements != 0:
|
||||
cm[torch.isnan(cm)] = 0
|
||||
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
|
||||
|
||||
return cm
|
||||
|
||||
|
|
|
@ -187,6 +187,17 @@ def test_confusion_matrix():
|
|||
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
|
||||
assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]]))
|
||||
|
||||
# Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
|
||||
target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9)
|
||||
pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15)
|
||||
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
|
||||
assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]]))
|
||||
to_compare = cm / torch.tensor([[13.], [16.], [9.]])
|
||||
|
||||
cm = confusion_matrix(pred, target, normalize=True, num_classes=3)
|
||||
assert torch.allclose(cm, to_compare)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
|
||||
pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
|
||||
|
|
|
@ -53,6 +53,21 @@ def test_confusion_matrix(normalize, num_classes):
|
|||
assert isinstance(cm, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['normalize', 'num_classes'], [
|
||||
pytest.param(True, 3)
|
||||
])
|
||||
def test_confusion_matrix_norm(normalize, num_classes):
|
||||
""" test that user is warned if confusion matrix contains nans that are changed to zeros"""
|
||||
conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
|
||||
assert conf_matrix.name == 'confusion_matrix'
|
||||
|
||||
with pytest.warns(UserWarning, match='6 nan values found in confusion matrix have been replaced with zeros.'):
|
||||
target = torch.LongTensor([0] * 5)
|
||||
pred = target.clone()
|
||||
cm = conf_matrix(pred, target)
|
||||
assert isinstance(cm, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2.])
|
||||
def test_precision_recall(pos_label):
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
|
||||
|
|
Loading…
Reference in New Issue