# lightning/tests/metrics/classification/test_precision_recall.py
from functools import partial
# History: "Classification metrics overhaul: precision & recall (4/n)" (#4842), 2021-01-18 08:24:13 +00:00.
# Co-authored-by: chaton, Jirka Borovec.
from typing import Callable, Optional
import numpy as np
import pytest
import torch
from sklearn.metrics import precision_score, recall_score
from pytorch_lightning.metrics import Metric, Precision, Recall
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import precision, precision_recall, recall
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mcls_prob
from tests.metrics.classification.inputs import _input_multidim_multiclass as _input_mdmc
from tests.metrics.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
from tests.metrics.classification.inputs import _input_multilabel as _input_mlb
from tests.metrics.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.metrics.utils import MetricTester, NUM_CLASSES, THRESHOLD

torch.manual_seed(42)


def _sk_prec_recall(preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average=None):
    if average == "none":
        average = None
    if num_classes == 1:
        average = "binary"

    labels = list(range(num_classes))
    try:
        labels.remove(ignore_index)
    except ValueError:
        pass

    sk_preds, sk_target, _ = _input_format_classification(
        preds, target, THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )
    sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()

    sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels)

    if len(labels) != num_classes and not average:
        sk_scores = np.insert(sk_scores, ignore_index, np.nan)

    return sk_scores


def _sk_prec_recall_multidim_multiclass(
    preds, target, sk_fn, num_classes, average, is_multiclass, ignore_index, mdmc_average
):
    preds, target, _ = _input_format_classification(
        preds, target, threshold=THRESHOLD, num_classes=num_classes, is_multiclass=is_multiclass
    )

    if mdmc_average == "global":
        preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
        target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])

        return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index)
    elif mdmc_average == "samplewise":
        scores = []

        for i in range(preds.shape[0]):
            pred_i = preds[i, ...].T
            target_i = target[i, ...].T
            scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index)

            scores.append(np.expand_dims(scores_i, 0))

        return np.concatenate(scores).mean(axis=0)


@pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)])
@pytest.mark.parametrize(
    "average, mdmc_average, num_classes, ignore_index, match_str",
    [
        ("wrong", None, None, None, "`average`"),
        ("micro", "wrong", None, None, "`mdmc"),
        ("macro", None, None, None, "number of classes"),
        ("macro", None, 1, 0, "ignore_index"),
    ],
)
def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str):
    with pytest.raises(ValueError, match=match_str):
        metric(
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        fn_metric(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )

    with pytest.raises(ValueError, match=match_str):
        precision_recall(
            _input_binary.preds[0],
            _input_binary.target[0],
            average=average,
            mdmc_average=mdmc_average,
            num_classes=num_classes,
            ignore_index=ignore_index,
        )


@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_zero_division(metric_class, metric_fn):
    """Test that zero_division works correctly (currently should just set to 0)."""

    preds = torch.tensor([1, 2, 1, 1])
    target = torch.tensor([2, 1, 2, 1])

    cl_metric = metric_class(average="none", num_classes=3)
    cl_metric(preds, target)

    result_cl = cl_metric.compute()
    result_fn = metric_fn(preds, target, average="none", num_classes=3)
2021-01-18 08:24:13 +00:00
assert result_cl[0] == result_fn[0] == 0
Classification metrics overhaul: precision & recall (4/n) (#4842) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add precision recall back * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Revert unwanted changes * Revert unwanted changes pt 2 * Update metrics interface * Add top_k parameter * Add back reduce function * Add stuff * PEP3 * Add depreciation * PEP8 * Deprecate param * PEP8 * Fix and simplify testing for older PT versions * Update Changelog * Remove redundant import * Add tests to increase coverage * Remove zero_division * fix zero_division * Add zero_div + edge case tests * Reorder cls metric args * Add back quotes for is_multiclass * Add precision_recall and tests * PEP8 * Fix docs * Fix docs * Update * Change precision_recall output * PEP8/isort * Add method _get_final_stats * Fix depr test * Add comment to deprecation tests * isort * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add typing to test * Add matc str to pytest.raises Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2021-01-18 08:24:13 +00:00
@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
def test_no_support(metric_class, metric_fn):
    """This tests a rare edge case, where there is only one class present
    in target, and ignore_index is set to exactly that class - and the
    average method is equal to 'weighted'.

    This would mean that the sum of weights equals zero, and would, without
    taking care of this case, return NaN. However, the reduction function
    should catch that and set the metric to equal the value of zero_division
    in this case (zero_division is for now not configurable and equals 0).
    """

    preds = torch.tensor([1, 1, 0, 0])
    target = torch.tensor([0, 0, 0, 0])

    cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0)
    cl_metric(preds, target)

    result_cl = cl_metric.compute()
    result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0)

    assert result_cl == result_fn == 0


@pytest.mark.parametrize(
    "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)]
)
@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
@pytest.mark.parametrize("ignore_index", [None, 0])
@pytest.mark.parametrize(
    "preds, target, num_classes, is_multiclass, mdmc_average, sk_wrapper",
    [
        (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall),
        (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall),
        (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall),
        (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global",
            _sk_prec_recall_multidim_multiclass
        ),
        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass),
        (
            _input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise",
            _sk_prec_recall_multidim_multiclass
        ),
    ],
)
class TestPrecisionRecall(MetricTester):
    @pytest.mark.parametrize("ddp", [False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_precision_recall_class(
        self,
        ddp: bool,
        dist_sync_on_step: bool,
        preds: torch.Tensor,
        target: torch.Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        is_multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")
        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
        self.run_class_metric_test(
            ddp=ddp,
            preds=preds,
            target=target,
            metric_class=metric_class,
            sk_metric=partial(
                sk_wrapper,
                sk_fn=sk_fn,
                average=average,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                ignore_index=ignore_index,
                mdmc_average=mdmc_average,
            ),
            dist_sync_on_step=dist_sync_on_step,
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "is_multiclass": is_multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
            check_dist_sync_on_step=True,
            check_batch=True,
        )
2021-01-18 08:24:13 +00:00
    def test_precision_recall_fn(
        self,
        preds: torch.Tensor,
        target: torch.Tensor,
        sk_wrapper: Callable,
        metric_class: Metric,
        metric_fn: Callable,
        sk_fn: Callable,
        is_multiclass: Optional[bool],
        num_classes: Optional[int],
        average: str,
        mdmc_average: Optional[str],
        ignore_index: Optional[int],
    ):
        if num_classes == 1 and average != "micro":
            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")

        if ignore_index is not None and preds.ndim == 2:
            pytest.skip("Skipping ignore_index test with binary inputs.")

        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")

        self.run_functional_metric_test(
            preds,
            target,
            metric_functional=metric_fn,
            sk_metric=partial(
                sk_wrapper,
                sk_fn=sk_fn,
                average=average,
                num_classes=num_classes,
                is_multiclass=is_multiclass,
                ignore_index=ignore_index,
                mdmc_average=mdmc_average,
            ),
            metric_args={
                "num_classes": num_classes,
                "average": average,
                "threshold": THRESHOLD,
                "is_multiclass": is_multiclass,
                "ignore_index": ignore_index,
                "mdmc_average": mdmc_average,
            },
        )


@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
def test_precision_recall_joint(average):
    """A simple test of the joint precision_recall metric.

    No need to test this thoroughly, as it is just a combination of precision and recall,
    which are already tested thoroughly.
    """

    precision_result = precision(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )
    recall_result = recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    prec_recall_result = precision_recall(
        _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES
    )

    assert torch.equal(precision_result, prec_recall_result[0])
    assert torch.equal(recall_result, prec_recall_result[1])


_mc_k_target = torch.tensor([0, 1, 2])
_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
_ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
_ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])


@pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)])
@pytest.mark.parametrize(
    "k, preds, target, average, expected_prec, expected_recall",
    [
        (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)),
        (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(1 / 2), torch.tensor(1.0)),
        (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)),
        (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(1 / 6), torch.tensor(1 / 3)),
    ],
)
def test_top_k(
    metric_class,
    metric_fn,
    k: int,
    preds: torch.Tensor,
    target: torch.Tensor,
    average: str,
    expected_prec: torch.Tensor,
    expected_recall: torch.Tensor,
):
    """A simple test to check that top_k works as expected.

    Just a sanity check; the tests in StatScores should already guarantee
    the correctness of results.
    """

    class_metric = metric_class(top_k=k, average=average, num_classes=3)
    class_metric.update(preds, target)

    if metric_class.__name__ == "Precision":
        result = expected_prec
    else:
        result = expected_recall

    assert torch.equal(class_metric.compute(), result)
    assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result)
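

# The expected values in the test_top_k parametrization can be reproduced by hand.
# What follows is a minimal pure-Python sketch, not the library's implementation:
# it assumes that with micro averaging the k highest-scoring classes of every
# sample count as predicted positives. The helper name
# `topk_micro_precision_recall` is hypothetical and exists only for illustration.

```python
def topk_micro_precision_recall(preds, target, k):
    """Micro-averaged top-k precision and recall on plain Python lists.

    preds:  list of per-class score rows, one row per sample.
    target: list of class indices (multi-class) or list of 0/1 rows (multi-label).
    """
    num_classes = len(preds[0])
    tp = predicted_pos = actual_pos = 0
    for scores, tgt in zip(preds, target):
        # Expand a class index into a one-hot row so both input types share one path.
        if isinstance(tgt, int):
            tgt = [int(c == tgt) for c in range(num_classes)]
        # The k highest-scoring classes are treated as predicted positives.
        top = sorted(range(num_classes), key=lambda c: scores[c], reverse=True)[:k]
        tp += sum(tgt[c] for c in top)
        predicted_pos += k
        actual_pos += sum(tgt)
    precision = tp / predicted_pos
    # Assumption: recall is reported as 0.0 when there are no positive targets.
    recall = tp / actual_pos if actual_pos else 0.0
    return precision, recall


mc_preds = [[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]
mc_target = [0, 1, 2]
ml_preds = [[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]
ml_target = [[0, 1, 0], [1, 1, 0], [0, 0, 0]]

print(topk_micro_precision_recall(mc_preds, mc_target, k=1))
print(topk_micro_precision_recall(mc_preds, mc_target, k=2))
print(topk_micro_precision_recall(ml_preds, ml_target, k=1))
print(topk_micro_precision_recall(ml_preds, ml_target, k=2))
```

# For example, in the multi-class case with k=2 every target falls inside the top
# two predictions, so there are 3 true positives out of 6 predicted positives
# (precision 1/2) and out of 3 actual positives (recall 1.0).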