diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 527770b430..5a14aa22ed 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -312,7 +312,11 @@ def confusion_matrix( cm = bins.reshape(num_classes, num_classes).squeeze().float() if normalize: - cm = cm / cm.sum(-1) + cm = cm / cm.sum(-1, keepdim=True) + nan_elements = cm[torch.isnan(cm)].nelement() + if nan_elements != 0: + cm[torch.isnan(cm)] = 0 + rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') return cm diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 8431846ba0..2074f70db5 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -187,6 +187,17 @@ def test_confusion_matrix(): cm = confusion_matrix(pred, target, normalize=False, num_classes=3) assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]])) + # Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html + target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9) + pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15) + cm = confusion_matrix(pred, target, normalize=False, num_classes=3) + assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]])) + to_compare = cm / torch.tensor([[13.], [16.], [9.]]) + + cm = confusion_matrix(pred, target, normalize=True, num_classes=3) + assert torch.allclose(cm, to_compare) + + @pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), diff --git a/tests/metrics/test_classification.py b/tests/metrics/test_classification.py index fd6e1a3d1a..1be023b2af 100644 --- a/tests/metrics/test_classification.py +++ b/tests/metrics/test_classification.py @@ -53,6 +53,21 @@ def test_confusion_matrix(normalize, num_classes): assert isinstance(cm, torch.Tensor) +@pytest.mark.parametrize(['normalize', 'num_classes'], [ + pytest.param(True, 3) +]) +def test_confusion_matrix_norm(normalize, num_classes): + """ test that user is warned if confusion matrix contains nans that are changed to zeros""" + conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes) + assert conf_matrix.name == 'confusion_matrix' + + with pytest.warns(UserWarning, match='6 nan values found in confusion matrix have been replaced with zeros.'): + target = torch.LongTensor([0] * 5) + pred = target.clone() + cm = conf_matrix(pred, target) + assert isinstance(cm, torch.Tensor) + + @pytest.mark.parametrize('pos_label', [1, 2.]) def test_precision_recall(pos_label): pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])