Update metrics to use Enum (#5689)

- Add DataType, AverageMethod and MDMCAverageMethod
This commit is contained in:
yuntai 2021-02-02 03:50:10 +09:00 committed by GitHub
parent 8943d8bca0
commit 50fd4879a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 97 additions and 57 deletions

View File

@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enum in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689)
- Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))

View File

@ -17,6 +17,39 @@ import numpy as np
import torch
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from pytorch_lightning.utilities import LightningEnum
class DataType(LightningEnum):
"""
Enum to represent data type
"""
BINARY = "binary"
MULTILABEL = "multi-label"
MULTICLASS = "multi-class"
MULTIDIM_MULTICLASS = "multi-dim multi-class"
class AverageMethod(LightningEnum):
"""
Enum to represent average method
"""
MICRO = "micro"
MACRO = "macro"
WEIGHTED = "weighted"
NONE = "none"
SAMPLES = "samples"
class MDMCAverageMethod(LightningEnum):
"""
Enum to represent multi-dim multi-class average method
"""
GLOBAL = "global"
SAMPLEWISE = "samplewise"
def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
@ -78,13 +111,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
# Get the case
if preds.ndim == 1 and preds_float:
case = "binary"
case = DataType.BINARY
elif preds.ndim == 1 and not preds_float:
case = "multi-class"
case = DataType.MULTICLASS
elif preds.ndim > 1 and preds_float:
case = "multi-label"
case = DataType.MULTILABEL
else:
case = "multi-dim multi-class"
case = DataType.MULTIDIM_MULTICLASS
implied_classes = preds[0].numel()
@ -100,9 +133,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
implied_classes = preds.shape[1]
if preds.ndim == 2:
case = "multi-class"
case = DataType.MULTICLASS
else:
case = "multi-dim multi-class"
case = DataType.MULTIDIM_MULTICLASS
else:
raise ValueError(
"Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
@ -182,7 +215,7 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes
def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
if case == "binary":
if case == DataType.BINARY:
raise ValueError("You can not use `top_k` parameter with binary data.")
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("The `top_k` has to be an integer larger than 0.")
@ -190,7 +223,7 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
if is_multiclass is False:
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
if case == "multi-label" and is_multiclass:
if case == DataType.MULTILABEL and is_multiclass:
raise ValueError(
"If you want to transform multi-label data to 2 class multi-dimensional"
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
@ -266,7 +299,7 @@ def _check_classification_inputs(
case, implied_classes = _check_shape_and_type_consistency(preds, target)
# For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
if "multi-class" in case and preds.is_floating_point():
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
@ -284,11 +317,11 @@ def _check_classification_inputs(
# Check that num_classes is consistent
if num_classes:
if case == "binary":
if case == DataType.BINARY:
_check_num_classes_binary(num_classes, is_multiclass)
elif "multi-class" in case:
elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
_check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
elif case == "multi-label":
elif case.MULTILABEL:
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)
# Check that top_k is consistent
@ -406,14 +439,14 @@ def _input_format_classification(
top_k=top_k,
)
if case in ["binary", "multi-label"] and not top_k:
if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
preds = (preds >= threshold).int()
num_classes = num_classes if not is_multiclass else 2
if case == "multi-label" and top_k:
if case == DataType.MULTILABEL and top_k:
preds = select_topk(preds, top_k)
if "multi-class" in case or is_multiclass:
if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
if preds.is_floating_point():
num_classes = preds.shape[1]
preds = select_topk(preds, top_k or 1)
@ -426,7 +459,7 @@ def _input_format_classification(
if is_multiclass is False:
preds, target = preds[:, 1, ...], target[:, 1, ...]
if ("multi-class" in case and is_multiclass is not False) or is_multiclass:
if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass:
target = target.reshape(target.shape[0], target.shape[1], -1)
preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
else:
@ -486,7 +519,7 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
if average not in ["micro", "none", None]:
if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
weights = weights / weights.sum(dim=-1, keepdim=True)
scores = weights * (numerator / denominator)
@ -494,11 +527,11 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
if mdmc_average == "samplewise":
if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool()
if average in ["none", None]:
if average in (AverageMethod.NONE, None):
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
else:
scores = scores.sum()

View File

@ -15,7 +15,7 @@ from typing import Optional, Tuple
import torch
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
def _accuracy_update(
@ -24,19 +24,19 @@ def _accuracy_update(
preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
if mode == "multi-label" and top_k:
if mode == DataType.MULTILABEL and top_k:
raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
if mode == "binary" or (mode == "multi-label" and subset_accuracy):
if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
correct = (preds == target).all(dim=1).sum()
total = torch.tensor(target.shape[0], device=target.device)
elif mode == "multi-label" and not subset_accuracy:
elif mode == DataType.MULTILABEL and not subset_accuracy:
correct = (preds == target).sum()
total = torch.tensor(target.numel(), device=target.device)
elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy):
elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
correct = (preds * target).sum()
total = target.sum()
elif mode == "multi-dim multi-class" and subset_accuracy:
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
sample_correct = (preds * target).sum(dim=(1, 2))
correct = (sample_correct == target.shape[2]).sum()
total = torch.tensor(target.shape[0], device=target.device)

View File

@ -16,7 +16,7 @@ from typing import Optional, Sequence, Tuple
import torch
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum
@ -102,7 +102,7 @@ def _auroc_compute(
elif average == AverageMethods.MACRO:
return torch.mean(torch.stack(auc_scores))
elif average == AverageMethods.WEIGHTED:
if mode == 'multi-label':
if mode == DataType.MULTILABEL:
support = torch.sum(target, dim=0)
else:
support = torch.bincount(target.flatten(), minlength=num_classes)

View File

@ -15,7 +15,7 @@ from typing import Optional
import torch
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.utilities import rank_zero_warn
@ -23,7 +23,7 @@ def _confusion_matrix_update(
preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5
) -> torch.Tensor:
preds, target, mode = _input_format_classification(preds, target, threshold)
if mode not in ('binary', 'multi-label'):
if mode not in (DataType.BINARY, DataType.MULTILABEL):
preds = preds.argmax(dim=1)
target = target.argmax(dim=1)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)

View File

@ -6,7 +6,7 @@ import torch
from sklearn.metrics import accuracy_score as sk_accuracy
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import (
_binary_inputs,
@ -29,12 +29,12 @@ def _sk_accuracy(preds, target, subset_accuracy):
sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()
if mode == "multi-dim multi-class" and not subset_accuracy:
if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy:
sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1))
sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2])
elif mode == mode == "multi-dim multi-class" and subset_accuracy:
elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
return np.all(sk_preds == sk_target, axis=(1, 2)).mean()
elif mode == "multi-label" and not subset_accuracy:
elif mode == DataType.MULTILABEL and not subset_accuracy:
sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1)
return sk_accuracy(y_true=sk_target, y_pred=sk_preds)

View File

@ -2,7 +2,7 @@ import pytest
import torch
from torch import rand, randint
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.utils import select_topk, to_onehot
from tests.metrics.classification.inputs import _binary_inputs as _bin
from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
@ -155,32 +155,36 @@ def _mlmd_prob_to_mc_preds_tr(x):
],
)
def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0],
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
def __get_data_type_enum(str_exp_mode):
return next(DataType[n] for n in dir(DataType) if DataType[n] == str_exp_mode)
assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
assert torch.equal(target_out, post_target(inputs.target[0]).int())
for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)):
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0],
target=inputs.target[0],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
# Test that things work when batch_size = 1
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0][[0], ...],
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
assert torch.equal(target_out, post_target(inputs.target[0]).int())
assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
# Test that things work when batch_size = 1
preds_out, target_out, mode = _input_format_classification(
preds=inputs.preds[0][[0], ...],
target=inputs.target[0][[0], ...],
threshold=THRESHOLD,
num_classes=num_classes,
is_multiclass=is_multiclass,
top_k=top_k,
)
assert mode == exp_mode
assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
# Test that threshold is correctly applied