Update metrics to use Enum (#5689)

- Add DataType, AverageMethod and MDMCAverageMethod
This commit is contained in:
yuntai 2021-02-02 03:50:10 +09:00 committed by GitHub
parent 8943d8bca0
commit 50fd4879a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 97 additions and 57 deletions

View File

@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added ### Added
- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enums in metrics ([#5689](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689))
- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590)) - Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))

View File

@ -17,6 +17,39 @@ import numpy as np
import torch import torch
from pytorch_lightning.metrics.utils import select_topk, to_onehot from pytorch_lightning.metrics.utils import select_topk, to_onehot
from pytorch_lightning.utilities import LightningEnum
class DataType(LightningEnum):
    """Kind of classification input distinguished by the metric helpers.

    Each member's value is the plain string label for that case.
    """

    # NOTE(review): the tests compare the detected mode against both the raw
    # string and the enum member, so LightningEnum presumably compares equal
    # to its string value -- confirm against LightningEnum itself.
    BINARY = "binary"
    MULTILABEL = "multi-label"
    MULTICLASS = "multi-class"
    MULTIDIM_MULTICLASS = "multi-dim multi-class"
class AverageMethod(LightningEnum):
    """Names of the averaging strategies accepted for the ``average`` option.

    Each member's value is the plain string label for that strategy.
    """

    MICRO = "micro"
    MACRO = "macro"
    WEIGHTED = "weighted"
    NONE = "none"
    SAMPLES = "samples"
class MDMCAverageMethod(LightningEnum):
    """Names of the averaging strategies for multi-dim multi-class inputs.

    Used for the ``mdmc_average`` option; each member's value is the plain
    string label for that strategy.
    """

    GLOBAL = "global"
    SAMPLEWISE = "samplewise"
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
@ -78,13 +111,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
# Get the case # Get the case
if preds.ndim == 1 and preds_float: if preds.ndim == 1 and preds_float:
case = "binary" case = DataType.BINARY
elif preds.ndim == 1 and not preds_float: elif preds.ndim == 1 and not preds_float:
case = "multi-class" case = DataType.MULTICLASS
elif preds.ndim > 1 and preds_float: elif preds.ndim > 1 and preds_float:
case = "multi-label" case = DataType.MULTILABEL
else: else:
case = "multi-dim multi-class" case = DataType.MULTIDIM_MULTICLASS
implied_classes = preds[0].numel() implied_classes = preds[0].numel()
@ -100,9 +133,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
implied_classes = preds.shape[1] implied_classes = preds.shape[1]
if preds.ndim == 2: if preds.ndim == 2:
case = "multi-class" case = DataType.MULTICLASS
else: else:
case = "multi-dim multi-class" case = DataType.MULTIDIM_MULTICLASS
else: else:
raise ValueError( raise ValueError(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
@ -182,7 +215,7 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes
def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
if case == "binary": if case == DataType.BINARY:
raise ValueError("You can not use `top_k` parameter with binary data.") raise ValueError("You can not use `top_k` parameter with binary data.")
if not isinstance(top_k, int) or top_k <= 0: if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("The `top_k` has to be an integer larger than 0.") raise ValueError("The `top_k` has to be an integer larger than 0.")
@ -190,7 +223,7 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt
raise ValueError("You have set `top_k`, but you do not have probability predictions.") raise ValueError("You have set `top_k`, but you do not have probability predictions.")
if is_multiclass is False: if is_multiclass is False:
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
if case == "multi-label" and is_multiclass: if case == DataType.MULTILABEL and is_multiclass:
raise ValueError( raise ValueError(
"If you want to transform multi-label data to 2 class multi-dimensional" "If you want to transform multi-label data to 2 class multi-dimensional"
"multi-class data using `is_multiclass=True`, you can not use `top_k`." "multi-class data using `is_multiclass=True`, you can not use `top_k`."
@ -266,7 +299,7 @@ def _check_classification_inputs(
case, implied_classes = _check_shape_and_type_consistency(preds, target) case, implied_classes = _check_shape_and_type_consistency(preds, target)
# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
if "multi-class" in case and preds.is_floating_point(): if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
@ -284,11 +317,11 @@ def _check_classification_inputs(
# Check that num_classes is consistent # Check that num_classes is consistent
if num_classes: if num_classes:
if case == "binary": if case == DataType.BINARY:
_check_num_classes_binary(num_classes, is_multiclass) _check_num_classes_binary(num_classes, is_multiclass)
elif "multi-class" in case: elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
elif case == "multi-label": elif case == DataType.MULTILABEL:
_check_num_classes_ml(num_classes, is_multiclass, implied_classes) _check_num_classes_ml(num_classes, is_multiclass, implied_classes)
# Check that top_k is consistent # Check that top_k is consistent
@ -406,14 +439,14 @@ def _input_format_classification(
top_k=top_k, top_k=top_k,
) )
if case in ["binary", "multi-label"] and not top_k: if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
preds = (preds >= threshold).int() preds = (preds >= threshold).int()
num_classes = num_classes if not is_multiclass else 2 num_classes = num_classes if not is_multiclass else 2
if case == "multi-label" and top_k: if case == DataType.MULTILABEL and top_k:
preds = select_topk(preds, top_k) preds = select_topk(preds, top_k)
if "multi-class" in case or is_multiclass: if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
if preds.is_floating_point(): if preds.is_floating_point():
num_classes = preds.shape[1] num_classes = preds.shape[1]
preds = select_topk(preds, top_k or 1) preds = select_topk(preds, top_k or 1)
@ -426,7 +459,7 @@ def _input_format_classification(
if is_multiclass is False: if is_multiclass is False:
preds, target = preds[:, 1, ...], target[:, 1, ...] preds, target = preds[:, 1, ...], target[:, 1, ...]
if ("multi-class" in case and is_multiclass is not False) or is_multiclass: if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass:
target = target.reshape(target.shape[0], target.shape[1], -1) target = target.reshape(target.shape[0], target.shape[1], -1)
preds = preds.reshape(preds.shape[0], preds.shape[1], -1) preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
else: else:
@ -486,7 +519,7 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
if average not in ["micro", "none", None]: if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
weights = weights / weights.sum(dim=-1, keepdim=True) weights = weights / weights.sum(dim=-1, keepdim=True)
scores = weights * (numerator / denominator) scores = weights * (numerator / denominator)
@ -494,11 +527,11 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
if mdmc_average == "samplewise": if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
scores = scores.mean(dim=0) scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool() ignore_mask = ignore_mask.sum(dim=0).bool()
if average in ["none", None]: if average in (AverageMethod.NONE, None):
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
else: else:
scores = scores.sum() scores = scores.sum()

View File

@ -15,7 +15,7 @@ from typing import Optional, Tuple
import torch import torch
from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
def _accuracy_update( def _accuracy_update(
@ -24,19 +24,19 @@ def _accuracy_update(
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
if mode == "multi-label" and top_k: if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.") raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
if mode == "binary" or (mode == "multi-label" and subset_accuracy): if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
correct = (preds == target).all(dim=1).sum() correct = (preds == target).all(dim=1).sum()
total = torch.tensor(target.shape[0], device=target.device) total = torch.tensor(target.shape[0], device=target.device)
elif mode == "multi-label" and not subset_accuracy: elif mode == DataType.MULTILABEL and not subset_accuracy:
correct = (preds == target).sum() correct = (preds == target).sum()
total = torch.tensor(target.numel(), device=target.device) total = torch.tensor(target.numel(), device=target.device)
elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy): elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
correct = (preds * target).sum() correct = (preds * target).sum()
total = target.sum() total = target.sum()
elif mode == "multi-dim multi-class" and subset_accuracy: elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
sample_correct = (preds * target).sum(dim=(1, 2)) sample_correct = (preds * target).sum(dim=(1, 2))
correct = (sample_correct == target.shape[2]).sum() correct = (sample_correct == target.shape[2]).sum()
total = torch.tensor(target.shape[0], device=target.device) total = torch.tensor(target.shape[0], device=target.device)

View File

@ -16,7 +16,7 @@ from typing import Optional, Sequence, Tuple
import torch import torch
from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional.auc import auc from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum from pytorch_lightning.utilities import LightningEnum
@ -102,7 +102,7 @@ def _auroc_compute(
elif average == AverageMethods.MACRO: elif average == AverageMethods.MACRO:
return torch.mean(torch.stack(auc_scores)) return torch.mean(torch.stack(auc_scores))
elif average == AverageMethods.WEIGHTED: elif average == AverageMethods.WEIGHTED:
if mode == 'multi-label': if mode == DataType.MULTILABEL:
support = torch.sum(target, dim=0) support = torch.sum(target, dim=0)
else: else:
support = torch.bincount(target.flatten(), minlength=num_classes) support = torch.bincount(target.flatten(), minlength=num_classes)

View File

@ -15,7 +15,7 @@ from typing import Optional
import torch import torch
from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities import rank_zero_warn
@ -23,7 +23,7 @@ def _confusion_matrix_update(
preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5 preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5
) -> torch.Tensor: ) -> torch.Tensor:
preds, target, mode = _input_format_classification(preds, target, threshold) preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in ('binary', 'multi-label'): if mode not in (DataType.BINARY, DataType.MULTILABEL):
preds = preds.argmax(dim=1) preds = preds.argmax(dim=1)
target = target.argmax(dim=1) target = target.argmax(dim=1)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long) unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)

View File

@ -6,7 +6,7 @@ import torch
from sklearn.metrics import accuracy_score as sk_accuracy from sklearn.metrics import accuracy_score as sk_accuracy
from pytorch_lightning.metrics import Accuracy from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import ( from tests.metrics.classification.inputs import (
_binary_inputs, _binary_inputs,
@ -29,12 +29,12 @@ def _sk_accuracy(preds, target, subset_accuracy):
sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()
if mode == "multi-dim multi-class" and not subset_accuracy: if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy:
sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1))
sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2])
elif mode == mode == "multi-dim multi-class" and subset_accuracy: elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
return np.all(sk_preds == sk_target, axis=(1, 2)).mean() return np.all(sk_preds == sk_target, axis=(1, 2)).mean()
elif mode == "multi-label" and not subset_accuracy: elif mode == DataType.MULTILABEL and not subset_accuracy:
sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1) sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1)
return sk_accuracy(y_true=sk_target, y_pred=sk_preds) return sk_accuracy(y_true=sk_target, y_pred=sk_preds)

View File

@ -2,7 +2,7 @@ import pytest
import torch import torch
from torch import rand, randint from torch import rand, randint
from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.utils import select_topk, to_onehot from pytorch_lightning.metrics.utils import select_topk, to_onehot
from tests.metrics.classification.inputs import _binary_inputs as _bin from tests.metrics.classification.inputs import _binary_inputs as _bin
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
@ -155,32 +155,36 @@ def _mlmd_prob_to_mc_preds_tr(x):
], ],
) )
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target): def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
preds_out, target_out, mode = _input_format_classification( def __get_data_type_enum(str_exp_mode):
preds=inputs.preds[0], return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
assert mode == exp_mode for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)):
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int()) preds_out, target_out, mode = _input_format_classification(
assert torch.equal(target_out, post_target(inputs.target[0]).int()) preds=inputs.preds[0],
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
# Test that things work when batch_size = 1 assert mode == exp_mode
preds_out, target_out, mode = _input_format_classification( assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
preds=inputs.preds[0][[0], ...], assert torch.equal(target_out, post_target(inputs.target[0]).int())
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
assert mode == exp_mode # Test that things work when batch_size = 1
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int()) preds_out, target_out, mode = _input_format_classification(
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int()) preds=inputs.preds[0][[0], ...],
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
# Test that threshold is correctly applied # Test that threshold is correctly applied