diff --git a/CHANGELOG.md b/CHANGELOG.md
index cd6f758010..57dc66f0fc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `DataType`, `AverageMethod` and `MDMCAverageMethod` enums in metrics ([#5657](https://github.com/PyTorchLightning/pytorch-lightning/pull/5689))
+
+
 - Added support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py
index 64cd3389e8..ea6d5722b3 100644
--- a/pytorch_lightning/metrics/classification/helpers.py
+++ b/pytorch_lightning/metrics/classification/helpers.py
@@ -17,6 +17,39 @@ import numpy as np
 import torch
 
 from pytorch_lightning.metrics.utils import select_topk, to_onehot
+from pytorch_lightning.utilities import LightningEnum
+
+
+class DataType(LightningEnum):
+    """
+    Enum to represent data type
+    """
+
+    BINARY = "binary"
+    MULTILABEL = "multi-label"
+    MULTICLASS = "multi-class"
+    MULTIDIM_MULTICLASS = "multi-dim multi-class"
+
+
+class AverageMethod(LightningEnum):
+    """
+    Enum to represent average method
+    """
+
+    MICRO = "micro"
+    MACRO = "macro"
+    WEIGHTED = "weighted"
+    NONE = "none"
+    SAMPLES = "samples"
+
+
+class MDMCAverageMethod(LightningEnum):
+    """
+    Enum to represent multi-dim multi-class average method
+    """
+
+    GLOBAL = "global"
+    SAMPLEWISE = "samplewise"
 
 
 def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool):
@@ -78,13 +111,13 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
 
         # Get the case
         if preds.ndim == 1 and preds_float:
-            case = "binary"
+            case = DataType.BINARY
         elif preds.ndim == 1 and not preds_float:
-            case = "multi-class"
+            case = DataType.MULTICLASS
         elif preds.ndim > 1 and preds_float:
-            case = "multi-label"
+            case = DataType.MULTILABEL
         else:
-            case = "multi-dim multi-class"
+            case = DataType.MULTIDIM_MULTICLASS
 
         implied_classes = preds[0].numel()
 
@@ -100,9 +133,9 @@ def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor)
         implied_classes = preds.shape[1]
 
         if preds.ndim == 2:
-            case = "multi-class"
+            case = DataType.MULTICLASS
         else:
-            case = "multi-dim multi-class"
+            case = DataType.MULTIDIM_MULTICLASS
     else:
         raise ValueError(
             "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
@@ -182,7 +215,7 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes
 
 
 def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
-    if case == "binary":
+    if case == DataType.BINARY:
         raise ValueError("You can not use `top_k` parameter with binary data.")
     if not isinstance(top_k, int) or top_k <= 0:
         raise ValueError("The `top_k` has to be an integer larger than 0.")
@@ -190,7 +223,7 @@ def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Opt
         raise ValueError("You have set `top_k`, but you do not have probability predictions.")
     if is_multiclass is False:
         raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
-    if case == "multi-label" and is_multiclass:
+    if case == DataType.MULTILABEL and is_multiclass:
         raise ValueError(
             "If you want to transform multi-label data to 2 class multi-dimensional"
             "multi-class data using `is_multiclass=True`, you can not use `top_k`."
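The new `DataType`, `AverageMethod` and `MDMCAverageMethod` members are intended as drop-in replacements for the string literals they replace in the hunks above: `LightningEnum` (from `pytorch_lightning.utilities`) mixes in `str` and implements a case-insensitive `__eq__`, so call sites that still pass or check the old strings keep working. A minimal sketch of that behaviour (illustrative only, not part of the patch):

```python
from pytorch_lightning.metrics.classification.helpers import AverageMethod, DataType

# LightningEnum subclasses `str`, so members compare equal to the old string constants
assert DataType.MULTILABEL == "multi-label"
assert DataType.MULTICLASS == "Multi-Class"   # comparison ignores case
assert AverageMethod.WEIGHTED == "weighted"
assert DataType.BINARY != "multi-class"
```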
@@ -266,7 +299,7 @@ def _check_classification_inputs(
     case, implied_classes = _check_shape_and_type_consistency(preds, target)
 
     # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1
-    if "multi-class" in case and preds.is_floating_point():
+    if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point():
         if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all():
             raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.")
 
@@ -284,11 +317,11 @@ def _check_classification_inputs(
 
     # Check that num_classes is consistent
     if num_classes:
-        if case == "binary":
+        if case == DataType.BINARY:
             _check_num_classes_binary(num_classes, is_multiclass)
-        elif "multi-class" in case:
+        elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS):
             _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes)
-        elif case == "multi-label":
+        elif case == DataType.MULTILABEL:
             _check_num_classes_ml(num_classes, is_multiclass, implied_classes)
 
     # Check that top_k is consistent
@@ -406,14 +439,14 @@ def _input_format_classification(
         top_k=top_k,
     )
 
-    if case in ["binary", "multi-label"] and not top_k:
+    if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k:
         preds = (preds >= threshold).int()
         num_classes = num_classes if not is_multiclass else 2
 
-    if case == "multi-label" and top_k:
+    if case == DataType.MULTILABEL and top_k:
         preds = select_topk(preds, top_k)
 
-    if "multi-class" in case or is_multiclass:
+    if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass:
         if preds.is_floating_point():
             num_classes = preds.shape[1]
             preds = select_topk(preds, top_k or 1)
@@ -426,7 +459,7 @@ def _input_format_classification(
     if is_multiclass is False:
         preds, target = preds[:, 1, ...], target[:, 1, ...]
 
-    if ("multi-class" in case and is_multiclass is not False) or is_multiclass:
+    if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass:
         target = target.reshape(target.shape[0], target.shape[1], -1)
         preds = preds.reshape(preds.shape[0], preds.shape[1], -1)
     else:
@@ -486,7 +519,7 @@ model_evaluation.html#multiclass-and-multilabel-classification>`__.
     denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
     weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
 
-    if average not in ["micro", "none", None]:
+    if average not in (AverageMethod.MICRO, AverageMethod.NONE, None):
         weights = weights / weights.sum(dim=-1, keepdim=True)
 
     scores = weights * (numerator / denominator)
@@ -494,11 +527,11 @@
     # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
     scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
 
-    if mdmc_average == "samplewise":
+    if mdmc_average == MDMCAverageMethod.SAMPLEWISE:
         scores = scores.mean(dim=0)
         ignore_mask = ignore_mask.sum(dim=0).bool()
 
-    if average in ["none", None]:
+    if average in (AverageMethod.NONE, None):
         scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
     else:
         scores = scores.sum()
diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py
index 5d0bbd691e..b51ce2e678 100644
--- a/pytorch_lightning/metrics/functional/accuracy.py
+++ b/pytorch_lightning/metrics/functional/accuracy.py
@@ -15,7 +15,7 @@ from typing import Optional, Tuple
 
 import torch
 
-from pytorch_lightning.metrics.classification.helpers import _input_format_classification
+from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
 
 
 def _accuracy_update(
@@ -24,19 +24,19 @@ def _accuracy_update(
 
     preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k)
 
-    if mode == "multi-label" and top_k:
+    if mode == DataType.MULTILABEL and top_k:
         raise ValueError("You can not use the `top_k` parameter to calculate accuracy for multi-label inputs.")
 
-    if mode == "binary" or (mode == "multi-label" and subset_accuracy):
+    if mode == DataType.BINARY or (mode == DataType.MULTILABEL and subset_accuracy):
         correct = (preds == target).all(dim=1).sum()
         total = torch.tensor(target.shape[0], device=target.device)
-    elif mode == "multi-label" and not subset_accuracy:
+    elif mode == DataType.MULTILABEL and not subset_accuracy:
         correct = (preds == target).sum()
         total = torch.tensor(target.numel(), device=target.device)
-    elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy):
+    elif mode == DataType.MULTICLASS or (mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy):
         correct = (preds * target).sum()
         total = target.sum()
-    elif mode == "multi-dim multi-class" and subset_accuracy:
+    elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
         sample_correct = (preds * target).sum(dim=(1, 2))
         correct = (sample_correct == target.shape[2]).sum()
         total = torch.tensor(target.shape[0], device=target.device)
diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py
index fa8b34ea7b..29f5081295 100644
--- a/pytorch_lightning/metrics/functional/auroc.py
+++ b/pytorch_lightning/metrics/functional/auroc.py
@@ -16,7 +16,7 @@ from typing import Optional, Sequence, Tuple
 
 import torch
 
-from pytorch_lightning.metrics.classification.helpers import _input_format_classification
+from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
 from pytorch_lightning.metrics.functional.auc import auc
 from pytorch_lightning.metrics.functional.roc import roc
 from pytorch_lightning.utilities import LightningEnum
@@ -102,7 +102,7 @@ def _auroc_compute(
         elif average == AverageMethods.MACRO:
             return torch.mean(torch.stack(auc_scores))
         elif average == AverageMethods.WEIGHTED:
-            if mode == 'multi-label':
+            if mode == DataType.MULTILABEL:
                 support = torch.sum(target, dim=0)
             else:
                 support = torch.bincount(target.flatten(), minlength=num_classes)
diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py
index 1810af5796..a55619dd04 100644
--- a/pytorch_lightning/metrics/functional/confusion_matrix.py
+++ b/pytorch_lightning/metrics/functional/confusion_matrix.py
@@ -15,7 +15,7 @@ from typing import Optional
 
 import torch
 
-from pytorch_lightning.metrics.classification.helpers import _input_format_classification
+from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
 from pytorch_lightning.utilities import rank_zero_warn
 
 
@@ -23,7 +23,7 @@ def _confusion_matrix_update(
     preds: torch.Tensor, target: torch.Tensor, num_classes: int, threshold: float = 0.5
 ) -> torch.Tensor:
     preds, target, mode = _input_format_classification(preds, target, threshold)
-    if mode not in ('binary', 'multi-label'):
+    if mode not in (DataType.BINARY, DataType.MULTILABEL):
         preds = preds.argmax(dim=1)
         target = target.argmax(dim=1)
     unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py
index eb929caed9..70d05e9499 100644
--- a/tests/metrics/classification/test_accuracy.py
+++ b/tests/metrics/classification/test_accuracy.py
@@ -6,7 +6,7 @@ import torch
 from sklearn.metrics import accuracy_score as sk_accuracy
 
 from pytorch_lightning.metrics import Accuracy
-from pytorch_lightning.metrics.classification.helpers import _input_format_classification
+from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
 from pytorch_lightning.metrics.functional import accuracy
 from tests.metrics.classification.inputs import (
     _binary_inputs,
@@ -29,12 +29,12 @@ def _sk_accuracy(preds, target, subset_accuracy):
     sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD)
     sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy()
 
-    if mode == "multi-dim multi-class" and not subset_accuracy:
+    if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy:
         sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1))
         sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2])
-    elif mode == mode == "multi-dim multi-class" and subset_accuracy:
+    elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy:
         return np.all(sk_preds == sk_target, axis=(1, 2)).mean()
-    elif mode == "multi-label" and not subset_accuracy:
+    elif mode == DataType.MULTILABEL and not subset_accuracy:
         sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1)
 
     return sk_accuracy(y_true=sk_target, y_pred=sk_preds)
diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py
index 8c8c6b9033..bcbe9c3bd5 100644
--- a/tests/metrics/classification/test_inputs.py
+++ b/tests/metrics/classification/test_inputs.py
@@ -2,7 +2,7 @@ import pytest
 import torch
 from torch import rand, randint
 
-from pytorch_lightning.metrics.classification.helpers import _input_format_classification
+from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
 from pytorch_lightning.metrics.utils import select_topk, to_onehot
 from tests.metrics.classification.inputs import _binary_inputs as _bin
 from tests.metrics.classification.inputs import _binary_prob_inputs as _bin_prob
@@ -155,32 +155,36 @@ def _mlmd_prob_to_mc_preds_tr(x):
     ],
 )
 def test_usual_cases(inputs, num_classes, is_multiclass, top_k, exp_mode, post_preds, post_target):
-    preds_out, target_out, mode = _input_format_classification(
-        preds=inputs.preds[0],
-        target=inputs.target[0],
-        threshold=THRESHOLD,
-        num_classes=num_classes,
-        is_multiclass=is_multiclass,
-        top_k=top_k,
-    )
+    def __get_data_type_enum(str_exp_mode):
+        return next(member for member in DataType if member == str_exp_mode)
 
-    assert mode == exp_mode
-    assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
-    assert torch.equal(target_out, post_target(inputs.target[0]).int())
+    for exp_mode in (exp_mode, __get_data_type_enum(exp_mode)):
+        preds_out, target_out, mode = _input_format_classification(
+            preds=inputs.preds[0],
+            target=inputs.target[0],
+            threshold=THRESHOLD,
+            num_classes=num_classes,
+            is_multiclass=is_multiclass,
+            top_k=top_k,
+        )
 
-    # Test that things work when batch_size = 1
-    preds_out, target_out, mode = _input_format_classification(
-        preds=inputs.preds[0][[0], ...],
-        target=inputs.target[0][[0], ...],
-        threshold=THRESHOLD,
-        num_classes=num_classes,
-        is_multiclass=is_multiclass,
-        top_k=top_k,
-    )
+        assert mode == exp_mode
+        assert torch.equal(preds_out, post_preds(inputs.preds[0]).int())
+        assert torch.equal(target_out, post_target(inputs.target[0]).int())
 
-    assert mode == exp_mode
-    assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
-    assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
+        # Test that things work when batch_size = 1
+        preds_out, target_out, mode = _input_format_classification(
+            preds=inputs.preds[0][[0], ...],
+            target=inputs.target[0][[0], ...],
+            threshold=THRESHOLD,
+            num_classes=num_classes,
+            is_multiclass=is_multiclass,
+            top_k=top_k,
+        )
+
+        assert mode == exp_mode
+        assert torch.equal(preds_out, post_preds(inputs.preds[0][[0], ...]).int())
+        assert torch.equal(target_out, post_target(inputs.target[0][[0], ...]).int())
 
     # Test that threshold is correctly applied
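The updated test loops over both the string and the enum form of `exp_mode`, which relies on the same compatibility in the other direction: a plain string also compares equal to the corresponding member, so membership checks such as `mode not in (DataType.BINARY, DataType.MULTILABEL)` behave the same whether `mode` arrives as a `DataType` or as one of the old strings. A rough sketch, assuming the enums added in this patch:

```python
from pytorch_lightning.metrics.classification.helpers import DataType, MDMCAverageMethod

mode = "multi-label"  # old-style string mode
# tuple membership is evaluated with ==, which the enum members implement against plain strings
assert mode in (DataType.BINARY, DataType.MULTILABEL)
assert MDMCAverageMethod.SAMPLEWISE == "samplewise"
```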