lightning/pytorch_lightning/metrics/classification/helpers.py

540 lines
24 KiB
Python
Raw Normal View History

Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
Classification metrics overhaul: precision & recall (4/n) (#4842) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add precision recall back * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Revert unwanted changes * Revert unwanted changes pt 2 * Update metrics interface * Add top_k parameter * Add back reduce function * Add stuff * PEP3 * Add depreciation * PEP8 * Deprecate param * PEP8 * Fix and simplify testing for older PT versions * Update Changelog * Remove redundant import * Add tests to increase coverage * Remove zero_division * fix zero_division * Add zero_div + edge case tests * Reorder cls metric args * Add back quotes for is_multiclass * Add precision_recall and tests * PEP8 * Fix docs * Fix docs * Update * Change precision_recall output * PEP8/isort * Add method _get_final_stats * Fix depr test * Add comment to deprecation tests * isort * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add typing to test * Add matc str to pytest.raises Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2021-01-18 08:24:13 +00:00
import numpy as np
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
import torch
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from pytorch_lightning.utilities import LightningEnum
class DataType(LightningEnum):
"""
Enum to represent data type
"""
BINARY = "binary"
MULTILABEL = "multi-label"
MULTICLASS = "multi-class"
MULTIDIM_MULTICLASS = "multi-dim multi-class"
class AverageMethod(LightningEnum):
"""
Enum to represent average method
"""
MICRO = "micro"
MACRO = "macro"
WEIGHTED = "weighted"
NONE = "none"
SAMPLES = "samples"
class MDMCAverageMethod(LightningEnum):
"""
Enum to represent multi-dim multi-class average method
"""
GLOBAL = "global"
SAMPLEWISE = "samplewise"
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
"""
Perform basic validation of inputs that does not require deducing any information
of the type of inputs.
"""
if target.is_floating_point():
raise ValueError("The `target` has to be an integer tensor.")
if target.min() < 0:
raise ValueError("The `target` has to be a non-negative tensor.")
preds_float = preds.is_floating_point()
if not preds_float and preds.min() < 0:
raise ValueError("If `preds` are integers, they have to be non-negative.")
if not preds.shape[0] == target.shape[0]:
raise ValueError("The `preds` and `target` should have the same first dimension.")
if preds_float and (preds.min() < 0 or preds.max() > 1):
raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.")
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
if is_multiclass is False and target.max() > 1:
raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.")
if is_multiclass is False and not preds_float and preds.max() > 1:
raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.")
def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]:
"""
This checks that the shape and type of inputs are consistent with
each other and fall into one of the allowed input types (see the
documentation of docstring of ``_input_format_classification``). It does
not check for consistency of number of classes, other functions take
care of that.
It returns the name of the case in which the inputs fall, and the implied
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
multi-label data).
"""
preds_float = preds.is_floating_point()
if preds.ndim == target.ndim:
if preds.shape != target.shape:
raise ValueError(
"The `preds` and `target` should have the same shape,",
f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.",
)
if preds_float and target.max() > 1:
raise ValueError(
"If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary."
)
# Get the case
if preds.ndim == 1 and preds_float:
case = DataType.BINARY
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
elif preds.ndim == 1 and not preds_float:
case = DataType.MULTICLASS
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
elif preds.ndim > 1 and preds_float:
case = DataType.MULTILABEL
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
else:
case = DataType.MULTIDIM_MULTICLASS
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
implied_classes = preds[0].numel()
elif preds.ndim == target.ndim + 1:
if not preds_float:
raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.")
if preds.shape[2:] != target.shape[1:]:
raise ValueError(
"If `preds` have one dimension more than `target`, the shape of `preds` should be"
" (N, C, ...), and the shape of `target` should be (N, ...)."
)
implied_classes = preds.shape[1]
if preds.ndim == 2:
case = DataType.MULTICLASS
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
else:
case = DataType.MULTIDIM_MULTICLASS
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
else:
raise ValueError(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
" and `preds` should be (N, C, ...)."
)
return case, implied_classes
def _check_num_classes_binary(num_classes: int, is_multiclass: bool):
"""
This checks that the consistency of `num_classes` with the data
and `is_multiclass` param for binary data.
"""
if num_classes > 2:
raise ValueError("Your data is binary, but `num_classes` is larger than 2.")
if num_classes == 2 and not is_multiclass:
raise ValueError(
"Your data is binary and `num_classes=2`, but `is_multiclass` is not True."
" Set it to True if you want to transform binary data to multi-class format."
)
if num_classes == 1 and is_multiclass:
raise ValueError(
"You have binary data and have set `is_multiclass=True`, but `num_classes` is 1."
2020-12-21 08:53:09 +00:00
" Either set `is_multiclass=None`(default) or set `num_classes=2`"
" to transform binary data to multi-class format."
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
)
def _check_num_classes_mc(
preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int
):
"""
This checks that the consistency of `num_classes` with the data
and `is_multiclass` param for (multi-dimensional) multi-class data.
"""
if num_classes == 1 and is_multiclass is not False:
raise ValueError(
"You have set `num_classes=1`, but predictions are integers."
" If you want to convert (multi-dimensional) multi-class data with 2 classes"
" to binary/multi-label, set `is_multiclass=False`."
)
if num_classes > 1:
if is_multiclass is False:
if implied_classes != num_classes:
raise ValueError(
"You have set `is_multiclass=False`, but the implied number of classes "
" (from shape of inputs) does not match `num_classes`. If you are trying to"
" transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`"
" should be either None or the product of the size of extra dimensions (...)."
" See Input Types in Metrics documentation."
)
if num_classes <= target.max():
raise ValueError("The highest label in `target` should be smaller than `num_classes`.")
if num_classes <= preds.max():
raise ValueError("The highest label in `preds` should be smaller than `num_classes`.")
if preds.shape != target.shape and num_classes != implied_classes:
raise ValueError("The size of C dimension of `preds` does not match `num_classes`.")
def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int):
"""
This checks that the consistency of `num_classes` with the data
and `is_multiclass` param for multi-label data.
"""
if is_multiclass and num_classes != 2:
raise ValueError(
"Your have set `is_multiclass=True`, but `num_classes` is not equal to 2."
" If you are trying to transform multi-label data to 2 class multi-dimensional"
" multi-class, you should set `num_classes` to either 2 or None."
)
if not is_multiclass and num_classes != implied_classes:
raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.")
def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
if case == DataType.BINARY:
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
raise ValueError("You can not use `top_k` parameter with binary data.")
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("The `top_k` has to be an integer larger than 0.")
if not preds_float:
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
if is_multiclass is False:
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
if case == DataType.MULTILABEL and is_multiclass:
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
raise ValueError(
"If you want to transform multi-label data to 2 class multi-dimensional"
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
)
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
if top_k >= implied_classes:
raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.")
def _check_classification_inputs(
preds: torch.Tensor,
target: torch.Tensor,
threshold: float,
num_classes: Optional[int],
is_multiclass: bool,
top_k: Optional[int],
) -> str:
"""Performs error checking on inputs for classification.
This ensures that preds and target take one of the shape/type combinations that are
specified in ``_input_format_classification`` docstring. It also checks the cases of
over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class
cases) that there are only up to 2 distinct labels.
In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval.
When ``num_classes`` is given, it is checked that it is consitent with input cases (binary,
multi-label, ...), and that, if availible, the implied number of classes in the ``C``
dimension is consistent with it (as well as that max label in target is smaller than it).
When ``num_classes`` is not specified in these cases, consistency of the highest target
value against ``C`` dimension is checked for (multi-dimensional) multi-class cases.
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
If ``top_k`` is set (not None) for inputs that do not have probability predictions (and
are not binary), an error is raised. Similarly if ``top_k`` is set to a number that
is higher than or equal to the ``C`` dimension of ``preds``, an error is raised.
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
Preds and target tensors are expected to be squeezed already - all dimensions should be
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
greater than 1, except perhaps the first one (``N``).
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
Args:
preds: Tensor with predictions (labels or probabilities)
target: Tensor with ground truth labels, always integers (labels)
threshold:
Threshold probability value for transforming probability predictions to binary
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
(0,1) predictions, in the case of binary or multi-label inputs.
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
num_classes:
Number of classes. If not explicitly set, the number of classes will be infered
either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
tensor, where applicable.
top_k:
Number of highest probability entries for each sample to convert to 1s - relevant
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
only for inputs with probability predictions. The default value (``None``) will be
interepreted as 1 for these inputs. If this parameter is set for multi-label inputs,
it will take precedence over threshold.
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
Should be left unset (``None``) for inputs with label predictions.
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
for a more detailed explanation and examples.
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
Return:
case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or
'multi-dim multi-class'
"""
# Baisc validation (that does not need case/type information)
_basic_input_validation(preds, target, threshold, is_multiclass)
# Check that shape/types fall into one of the cases
case, implied_classes = _check_shape_and_type_consistency(preds, target)
# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
# Check consistency with the `C` dimension in case of multi-class data
if preds.shape != target.shape:
if is_multiclass is False and implied_classes != 2:
raise ValueError(
"You have set `is_multiclass=False`, but have more than 2 classes in your data,"
" based on the C dimension of `preds`."
)
if target.max() >= implied_classes:
raise ValueError(
"The highest label in `target` should be smaller than the size of the `C` dimension of `preds`."
)
# Check that num_classes is consistent
if num_classes:
if case == DataType.BINARY:
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
_check_num_classes_binary(num_classes, is_multiclass)
elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
elif case.MULTILABEL:
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)
# Check that top_k is consistent
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
if top_k is not None:
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
_check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point())
return case
def _input_format_classification(
preds: torch.Tensor,
target: torch.Tensor,
threshold: float = 0.5,
top_k: Optional[int] = None,
num_classes: Optional[int] = None,
is_multiclass: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor, str]:
"""Convert preds and target tensors into common format.
Preds and targets are supposed to fall into one of these categories (and are
validated to make sure this is the case):
* Both preds and target are of shape ``(N,)``, and both are integers (multi-class)
* Both preds and target are of shape ``(N,)``, and target is binary, while preds
are a float (binary)
* preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and
is integer (multi-class)
* preds and target are of shape ``(N, ...)``, target is binary and preds is a float
(multi-label)
* preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)``
and is integer (multi-dimensional multi-class)
* preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional
multi-class)
To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out.
The returned output tensors will be binary tensors of the same shape, either ``(N, C)``
of ``(N, C, X)``, the details for each case are described below. The function also returns
a ``case`` string, which describes which of the above cases the inputs belonged to - regardless
of whether this was "overridden" by other settings (like ``is_multiclass``).
In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed
into a binary tensor (elements become 1 if the probability is greater than or equal to
``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds
become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to
preds first.
In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets
by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original
shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be
returned as ``(N,1)`` tensor.
In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with
preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening
all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as
``(N, 2, C)``, by an equivalent transformation as in the binary case.
In multi-dimensional multi-class case, normally both target and preds are returned as
``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and
``C``. The transformations performed here are equivalent to the multi-class case. However, if
``is_multiclass=False`` (and there are up to two classes), then the data is returned as
``(N, X)`` binary tensors (multi-label).
Note that where a one-hot transformation needs to be performed and the number of classes
is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be
equal to ``num_classes``, if it is given, or the maximum label value in preds and
target.
Args:
preds: Tensor with predictions (labels or probabilities)
target: Tensor with ground truth labels, always integers (labels)
threshold:
Threshold probability value for transforming probability predictions to binary
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
(0 or 1) predictions, in the case of binary or multi-label inputs.
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
num_classes:
Number of classes. If not explicitly set, the number of classes will be infered
either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
tensor, where applicable.
top_k:
Number of highest probability entries for each sample to convert to 1s - relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
default value (``None``) will be interepreted as 1 for these inputs.
Should be left unset (``None``) for all other types of inputs.
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
than what they appear to be. See the parameter's
:ref:`documentation section <extensions/metrics:using the is_multiclass parameter>`
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
for a more detailed explanation and examples.
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
Returns:
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
target: binary tensor of shape ``(N, C)`` or ``(N, C, X)``
case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or
``'multi-dim multi-class'``
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
"""
# Remove excess dimensions
if preds.shape[0] == 1:
preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0)
else:
preds, target = preds.squeeze(), target.squeeze()
Classification metrics overhaul: accuracy metrics (2/n) (#4838) * Add stuff * Change metrics documentation layout * Add stuff * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * Division with float * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Suggestions from code review * Fix number in docs * Update pytorch_lightning/metrics/classification/accuracy.py * Replace topk by argsort in select_topk * Fix changelog * Add test for wrong params * Add Google Colab badges (#5111) * Add colab badges to notebook Add colab badges to notebook to notebooks 4 & 5 * Add colab badges Co-authored-by: chaton <thomas@grid.ai> * Fix hanging metrics tests (#5134) * Use torch.topk again as ddp hanging tests fixed in #5134 * Fix unwanted notebooks change * Fix too long line in hamming_distance * Apply suggestions from code review * Apply suggestions from code review * protect * Update CHANGELOG.md Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Roger Shieh <sh.rog@protonmail.ch> Co-authored-by: Shachar Mirkin <shacharmirkin@gmail.com>
2020-12-21 15:42:51 +00:00
# Convert half precision tensors to full precision, as not all ops are supported
# for example, min() is not supported
if preds.dtype == torch.float16:
preds = preds.float()
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
case = _check_classification_inputs(
preds,
target,
threshold=threshold,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
preds = (preds >= threshold).int()
num_classes = num_classes if not is_multiclass else 2
if case == DataType.MULTILABEL and top_k:
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
preds = select_topk(preds, top_k)
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
if preds.is_floating_point():
num_classes = preds.shape[1]
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
preds = select_topk(preds, top_k or 1)
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
else:
num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
preds = to_onehot(preds, max(2, num_classes))
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
Classification metrics overhaul: stat scores (3/n) (#4839) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update to new top_k default * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Split accuracy and hamming loss * Remove old redundant accuracy * Minor changes * Fix imports * Improve docstring descriptions * Fix imports * Fix edge case and simplify testing * Fix docs * PEP8 * Reorder imports * Add top_k parameter * Update changelog * Update docstring * Update docstring * Reverse formatting changes for tests * Change parameter order * Remove formatting changes 2/2 * Remove formatting 3/3 * . * Improve description of top_k parameter * Apply suggestions from code review * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Update pytorch_lightning/metrics/functional/accuracy.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Remove unneeded assert * Explicit checking of parameter values * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Apply suggestions from code review * Fix top_k checking * PEP8 * Don't check dist_sync in test * add back check_dist_sync_on_step * Make sure half-precision inputs are transformed (#5013) * Fix typo * Rename hamming loss to hamming distance * Fix tests for half precision * Fix docs underline length * Fix doc undeline length * Replace mdmc_accuracy parameter with subset_accuracy * Update changelog * Fix unwanted accuracy change * Enable top_k for ML prob inputs * Test that default threshold is 0.5 * Fix typo * Update top_k description in helpers * updates * Update styling and add back tests * Remove excess spaces * fix torch.where for old versions * fix linting * Update docstring * Fix docstring * Apply suggestions from code review (mostly docs) * Default threshold to None, accept only (0,1) * Change wrong threshold message * Improve documentation and add tests * Add back ddp tests * Change stat reduce method and default * Remove DDP tests and fix doctests * Fix doctest * Update changelog * Refactoring * Fix typo * Refactor * Increase coverage * Fix linting * Consistent use of backticks * Fix too long line in docs * Apply suggestions from code review * Fix deprecation test * Fix deprecation test * Default threshold back to 0.5 * Minor documentation fixes * Add types to tests Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
2020-12-30 19:49:50 +00:00
target = to_onehot(target, max(2, num_classes))
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
if is_multiclass is False:
preds, target = preds[:, 1, ...], target[:, 1, ...]
if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass:
Classification metrics overhaul: input formatting standardization (1/n) (#4837) * Add stuff * Change metrics documentation layout * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Break down error checking for inputs into separate functions * Remove the (N, ..., C) option in MD-MC * Simplify select_topk * Remove detach for inputs * Fix typos * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Minor error message changes * Update pytorch_lightning/metrics/utils.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Reuse case from validation in formatting * Refactor code in _input_format_classification * Small improvements * PEP 8 * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update docs/source/metrics.rst Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Update pytorch_lightning/metrics/classification/utils.py Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Apply suggestions from code review Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> * Alphabetical reordering of regression metrics * Change default value of top_k and add error checking * Extract basic validation into separate function * Update desciption of parameters in input formatting * Apply suggestions from code review Co-authored-by: Nicki Skafte <skaftenicki@gmail.com> * Check that probabilities in preds sum to 1 (for MC) * Fix coverage * Minor changes * Fix edge case and simplify testing Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
2020-12-07 16:49:35 +00:00
target = target.reshape(target.shape[0], target.shape[1], -1)
preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
else:
target = target.reshape(target.shape[0], -1)
preds = preds.reshape(preds.shape[0], -1)
# Some operatins above create an extra dimension for MC/binary case - this removes it
if preds.ndim > 2:
preds, target = preds.squeeze(-1), target.squeeze(-1)
return preds.int(), target.int(), case
Classification metrics overhaul: precision & recall (4/n) (#4842) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add precision recall back * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Revert unwanted changes * Revert unwanted changes pt 2 * Update metrics interface * Add top_k parameter * Add back reduce function * Add stuff * PEP3 * Add depreciation * PEP8 * Deprecate param * PEP8 * Fix and simplify testing for older PT versions * Update Changelog * Remove redundant import * Add tests to increase coverage * Remove zero_division * fix zero_division * Add zero_div + edge case tests * Reorder cls metric args * Add back quotes for is_multiclass * Add precision_recall and tests * PEP8 * Fix docs * Fix docs * Update * Change precision_recall output * PEP8/isort * Add method _get_final_stats * Fix depr test * Add comment to deprecation tests * isort * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add typing to test * Add matc str to pytest.raises Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2021-01-18 08:24:13 +00:00
def _reduce_stat_scores(
numerator: torch.Tensor,
denominator: torch.Tensor,
weights: Optional[torch.Tensor],
average: str,
mdmc_average: Optional[str],
zero_division: int = 0,
) -> torch.Tensor:
"""
Reduces scores of type ``numerator/denominator`` or
``weights * (numerator/denominator)``, if ``average='weighted'``.
Args:
numerator: A tensor with numerator numbers.
denominator: A tensor with denominator numbers. If a denominator is
negative, the class will be ignored (if averaging), or its score
will be returned as ``nan`` (if ``average=None``).
If the denominator is zero, then ``zero_division`` score will be
used for those elements.
weights:
A tensor of weights to be used if ``average='weighted'``.
average:
The method to average the scores. Should be one of ``'micro'``, ``'macro'``,
``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior
corresponds to `sklearn averaging methods <https://scikit-learn.org/stable/modules/\
model_evaluation.html#multiclass-and-multilabel-classification>`__.
mdmc_average:
The method to average the scores if inputs were multi-dimensional multi-class (MDMC).
Should be either ``'global'`` or ``'samplewise'``. If inputs were not
multi-dimensional multi-class, it should be ``None`` (default).
zero_division:
The value to use for the score if denominator equals zero.
"""
numerator, denominator = numerator.float(), denominator.float()
zero_div_mask = denominator == 0
ignore_mask = denominator < 0
if weights is None:
weights = torch.ones_like(denominator)
else:
weights = weights.float()
numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator)
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
Classification metrics overhaul: precision & recall (4/n) (#4842) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add precision recall back * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Revert unwanted changes * Revert unwanted changes pt 2 * Update metrics interface * Add top_k parameter * Add back reduce function * Add stuff * PEP3 * Add depreciation * PEP8 * Deprecate param * PEP8 * Fix and simplify testing for older PT versions * Update Changelog * Remove redundant import * Add tests to increase coverage * Remove zero_division * fix zero_division * Add zero_div + edge case tests * Reorder cls metric args * Add back quotes for is_multiclass * Add precision_recall and tests * PEP8 * Fix docs * Fix docs * Update * Change precision_recall output * PEP8/isort * Add method _get_final_stats * Fix depr test * Add comment to deprecation tests * isort * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add typing to test * Add matc str to pytest.raises Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2021-01-18 08:24:13 +00:00
weights = weights / weights.sum(dim=-1, keepdim=True)
scores = weights * (numerator / denominator)
# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
Classification metrics overhaul: precision & recall (4/n) (#4842) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add precision recall back * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Revert unwanted changes * Revert unwanted changes pt 2 * Update metrics interface * Add top_k parameter * Add back reduce function * Add stuff * PEP3 * Add depreciation * PEP8 * Deprecate param * PEP8 * Fix and simplify testing for older PT versions * Update Changelog * Remove redundant import * Add tests to increase coverage * Remove zero_division * fix zero_division * Add zero_div + edge case tests * Reorder cls metric args * Add back quotes for is_multiclass * Add precision_recall and tests * PEP8 * Fix docs * Fix docs * Update * Change precision_recall output * PEP8/isort * Add method _get_final_stats * Fix depr test * Add comment to deprecation tests * isort * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add typing to test * Add matc str to pytest.raises Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2021-01-18 08:24:13 +00:00
scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool()
if average in (AverageMethod.NONE, None):
Classification metrics overhaul: precision & recall (4/n) (#4842) * Add stuff * Change metrics documentation layout * Add stuff * Add stat scores * Change testing utils * Replace len(*.shape) with *.ndim * More descriptive error message for input formatting * Replace movedim with permute * PEP 8 compliance * WIP * Add reduce_scores function * Temporarily add back legacy class_reduce * Division with float * PEP 8 compliance * Remove precision recall * Replace movedim with permute * Add back tests * Add empty newlines * Add precision recall back * Add empty line * Fix permute * Fix some issues with old versions of PyTorch * Style changes in error messages * More error message style improvements * Fix typo in docs * Add more descriptive variable names in utils * Change internal var names * Revert unwanted changes * Revert unwanted changes pt 2 * Update metrics interface * Add top_k parameter * Add back reduce function * Add stuff * PEP3 * Add depreciation * PEP8 * Deprecate param * PEP8 * Fix and simplify testing for older PT versions * Update Changelog * Remove redundant import * Add tests to increase coverage * Remove zero_division * fix zero_division * Add zero_div + edge case tests * Reorder cls metric args * Add back quotes for is_multiclass * Add precision_recall and tests * PEP8 * Fix docs * Fix docs * Update * Change precision_recall output * PEP8/isort * Add method _get_final_stats * Fix depr test * Add comment to deprecation tests * isort * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Add typing to test * Add matc str to pytest.raises Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2021-01-18 08:24:13 +00:00
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
else:
scores = scores.sum()
return scores