fix normalize mode at confusion matrix (replace nans with zeros) (#3465)

* replace nans to 0 at conf. matrix & update tests

* cm.isnan() -> torch.isnan(cm)

* fix row-wise division while normalize

* update tests

* pep8 fix

* Update tests/metrics/test_classification.py

add comment to test

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Update tests/metrics/functional/test_classification.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* Update pytorch_lightning/metrics/functional/classification.py

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* final update

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
Cookie_thief 2020-09-14 11:05:51 +03:00 committed by GitHub
parent 87fc43e6b5
commit a552d4a2d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 1 deletions

View File

@ -312,7 +312,11 @@ def confusion_matrix(
cm = bins.reshape(num_classes, num_classes).squeeze().float()
if normalize:
cm = cm / cm.sum(-1)
cm = cm / cm.sum(-1, keepdim=True)
nan_elements = cm[torch.isnan(cm)].nelement()
if nan_elements != 0:
cm[torch.isnan(cm)] = 0
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
return cm

View File

@ -187,6 +187,17 @@ def test_confusion_matrix():
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]]))
# Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9)
pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15)
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]]))
to_compare = cm / torch.tensor([[13.], [16.], [9.]])
cm = confusion_matrix(pred, target, normalize=True, num_classes=3)
assert torch.allclose(cm, to_compare)
@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),

View File

@ -53,6 +53,21 @@ def test_confusion_matrix(normalize, num_classes):
assert isinstance(cm, torch.Tensor)
@pytest.mark.parametrize(['normalize', 'num_classes'], [
pytest.param(True, 3)
])
def test_confusion_matrix_norm(normalize, num_classes):
""" test that user is warned if confusion matrix contains nans that are changed to zeros"""
conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
assert conf_matrix.name == 'confusion_matrix'
with pytest.warns(UserWarning, match='6 nan values found in confusion matrix have been replaced with zeros.'):
target = torch.LongTensor([0] * 5)
pred = target.clone()
cm = conf_matrix(pred, target)
assert isinstance(cm, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2.])
def test_precision_recall(pos_label):
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])