From b3d25262e98cd3098a47faecc92053bff8944323 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 7 Oct 2020 18:48:58 +0200 Subject: [PATCH] restore functional metrics (#3943) * restore functional metrics * clean * fix --- .../metrics/functional/__init__.py | 31 + .../metrics/functional/classification.py | 964 ++++++++++++++++++ pytorch_lightning/metrics/functional/nlp.py | 92 ++ .../metrics/functional/reduction.py | 24 + .../metrics/functional/regression.py | 297 ++++++ tests/metrics/functional/__init__.py | 0 .../metrics/functional/test_classification.py | 341 +++++++ tests/metrics/functional/test_nlp.py | 66 ++ tests/metrics/functional/test_reduction.py | 15 + tests/metrics/functional/test_regression.py | 144 +++ 10 files changed, 1974 insertions(+) create mode 100644 pytorch_lightning/metrics/functional/__init__.py create mode 100644 pytorch_lightning/metrics/functional/classification.py create mode 100644 pytorch_lightning/metrics/functional/nlp.py create mode 100644 pytorch_lightning/metrics/functional/reduction.py create mode 100644 pytorch_lightning/metrics/functional/regression.py create mode 100644 tests/metrics/functional/__init__.py create mode 100644 tests/metrics/functional/test_classification.py create mode 100644 tests/metrics/functional/test_nlp.py create mode 100644 tests/metrics/functional/test_reduction.py create mode 100644 tests/metrics/functional/test_regression.py diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py new file mode 100644 index 0000000000..926803b504 --- /dev/null +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -0,0 +1,31 @@ +from pytorch_lightning.metrics.functional.classification import ( + accuracy, + auc, + auroc, + average_precision, + confusion_matrix, + dice_score, + f1_score, + fbeta_score, + multiclass_precision_recall_curve, + multiclass_roc, + precision, + precision_recall, + precision_recall_curve, + recall, + roc, + stat_scores, + stat_scores_multiple_classes, + to_categorical, + to_onehot, + iou, +) +from pytorch_lightning.metrics.functional.nlp import bleu_score +from pytorch_lightning.metrics.functional.regression import ( + mae, + mse, + psnr, + rmse, + rmsle, + ssim +) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py new file mode 100644 index 0000000000..b6acf05a64 --- /dev/null +++ b/pytorch_lightning/metrics/functional/classification.py @@ -0,0 +1,964 @@ +from collections import Sequence +from functools import wraps +from typing import Optional, Tuple, Callable + +import torch +from torch.nn import functional as F + +from pytorch_lightning.metrics.functional.reduction import reduce +from pytorch_lightning.utilities import rank_zero_warn, FLOAT16_EPSILON + + +def to_onehot( + tensor: torch.Tensor, + num_classes: Optional[int] = None, +) -> torch.Tensor: + """ + Converts a dense label tensor to one-hot format + + Args: + tensor: dense label tensor, with shape [N, d1, d2, ...] + num_classes: number of classes C + + Output: + A sparse label tensor with shape [N, C, d1, d2, ...] 
+ + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> to_onehot(x) + tensor([[0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + """ + if num_classes is None: + num_classes = int(tensor.max().detach().item() + 1) + dtype, device, shape = tensor.dtype, tensor.device, tensor.shape + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], + dtype=dtype, device=device) + index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + return tensor_onehot.scatter_(1, index, 1.0) + + +def to_categorical( + tensor: torch.Tensor, + argmax_dim: int = 1 +) -> torch.Tensor: + """ + Converts a tensor of probabilities to a dense label tensor + + Args: + tensor: probabilities to get the categorical label [N, d1, d2, ...] + argmax_dim: dimension to apply + + Return: + A tensor with categorical labels [N, d2, ...] + + Example: + + >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + >>> to_categorical(x) + tensor([1, 0]) + + """ + return torch.argmax(tensor, dim=argmax_dim) + + +def get_num_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, +) -> int: + """ + Calculates the number of classes for a given prediction and target tensor. + + Args: + pred: predicted values + target: true labels + num_classes: number of classes if known + + Return: + An integer that represents the number of classes. + """ + num_target_classes = int(target.max().detach().item() + 1) + num_pred_classes = int(pred.max().detach().item() + 1) + num_all_classes = max(num_target_classes, num_pred_classes) + + if num_classes is None: + num_classes = num_all_classes + elif num_classes != num_all_classes: + rank_zero_warn(f'You have set {num_classes} number of classes if different from' + f' predicted ({num_pred_classes}) and target ({num_target_classes}) number of classes') + return num_classes + + +def stat_scores( + pred: torch.Tensor, + target: torch.Tensor, + class_index: int, argmax_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Calculates the number of true positive, false positive, true negative + and false negative for a specific class + + Args: + pred: prediction tensor + target: target tensor + class_index: class to calculate over + argmax_dim: if pred is a tensor of probabilities, this indicates the + axis the argmax transformation will be applied over + + Return: + True Positive, False Positive, True Negative, False Negative, Support + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1) + >>> tp, fp, tn, fn, sup + (tensor(0), tensor(1), tensor(2), tensor(0), tensor(0)) + + """ + if pred.ndim == target.ndim + 1: + pred = to_categorical(pred, argmax_dim=argmax_dim) + + tp = ((pred == class_index) * (target == class_index)).to(torch.long).sum() + fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum() + tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum() + fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum() + sup = (target == class_index).to(torch.long).sum() + + return tp, fp, tn, fn, sup + + +def stat_scores_multiple_classes( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + argmax_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Calls the stat_scores function iteratively for all classes, thus + calculating the number of true postive, false postive, true negative + and false negative 
for each class + + Args: + pred: prediction tensor + target: target tensor + num_classes: number of classes if known + argmax_dim: if pred is a tensor of probabilities, this indicates the + axis the argmax transformation will be applied over + + Return: + True Positive, False Positive, True Negative, False Negative, Support + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y) + >>> tps + tensor([0., 0., 1., 1.]) + >>> fps + tensor([0., 1., 0., 0.]) + >>> tns + tensor([2., 2., 2., 2.]) + >>> fns + tensor([1., 0., 0., 0.]) + >>> sups + tensor([1., 0., 1., 1.]) + """ + num_classes = get_num_classes(pred=pred, target=target, + num_classes=num_classes) + + if pred.ndim == target.ndim + 1: + pred = to_categorical(pred, argmax_dim=argmax_dim) + + tps = torch.zeros((num_classes,), device=pred.device) + fps = torch.zeros((num_classes,), device=pred.device) + tns = torch.zeros((num_classes,), device=pred.device) + fns = torch.zeros((num_classes,), device=pred.device) + sups = torch.zeros((num_classes,), device=pred.device) + for c in range(num_classes): + tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c) + + return tps, fps, tns, fns, sups + + +def accuracy( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + reduction='elementwise_mean', +) -> torch.Tensor: + """ + Computes the accuracy classification score + + Args: + pred: predicted labels + target: ground truth labels + num_classes: number of classes + reduction: a method for reducing accuracies over labels (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + A Tensor with the classification score. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> accuracy(x, y) + tensor(0.7500) + + """ + tps, fps, tns, fns, sups = stat_scores_multiple_classes( + pred=pred, target=target, num_classes=num_classes) + + if not (target > 0).any() and num_classes is None: + raise RuntimeError("cannot infer num_classes when target is all zero") + + if reduction in ('elementwise_mean', 'sum'): + return reduce(sum(tps) / sum(sups), reduction=reduction) + if reduction == 'none': + return reduce(tps / sups, reduction=reduction) + + +def confusion_matrix( + pred: torch.Tensor, + target: torch.Tensor, + normalize: bool = False, +) -> torch.Tensor: + """ + Computes the confusion matrix C where each entry C_{i,j} is the number of observations + in group i that were predicted in group j. 
+ + Args: + pred: estimated targets + target: ground truth labels + normalize: normalizes confusion matrix + + Return: + Tensor, confusion matrix C [num_classes, num_classes ] + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> confusion_matrix(x, y) + tensor([[0., 1., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]]) + """ + num_classes = get_num_classes(pred, target, None) + + unique_labels = target.view(-1) * num_classes + pred.view(-1) + + bins = torch.bincount(unique_labels, minlength=num_classes ** 2) + cm = bins.reshape(num_classes, num_classes).squeeze().float() + + if normalize: + cm = cm / cm.sum(-1) + + return cm + + +def precision_recall( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Computes precision and recall for different thresholds + + Args: + pred: estimated probabilities + target: ground-truth labels + num_classes: number of classes + reduction: method for reducing precision-recall values (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor with precision and recall + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> precision_recall(x, y) + (tensor(0.7500), tensor(0.6250)) + + """ + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) + + tps = tps.to(torch.float) + fps = fps.to(torch.float) + fns = fns.to(torch.float) + + precision = tps / (tps + fps) + recall = tps / (tps + fns) + + # solution by justus, see https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/9 + precision[precision != precision] = 0 + recall[recall != recall] = 0 + + precision = reduce(precision, reduction=reduction) + recall = reduce(recall, reduction=reduction) + return precision, recall + + +def precision( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + """ + Computes precision score. + + Args: + pred: estimated probabilities + target: ground-truth labels + num_classes: number of classes + reduction: method for reducing precision values (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor with precision. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> precision(x, y) + tensor(0.7500) + + """ + return precision_recall(pred=pred, target=target, + num_classes=num_classes, reduction=reduction)[0] + + +def recall( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + """ + Computes recall score. + + Args: + pred: estimated probabilities + target: ground-truth labels + num_classes: number of classes + reduction: method for reducing recall values (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor with recall. 
+ + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> recall(x, y) + tensor(0.6250) + """ + return precision_recall(pred=pred, target=target, + num_classes=num_classes, reduction=reduction)[1] + + +def fbeta_score( + pred: torch.Tensor, + target: torch.Tensor, + beta: float, + num_classes: Optional[int] = None, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + """ + Computes the F-beta score which is a weighted harmonic mean of precision and recall. + It ranges between 1 and 0, where 1 is perfect and the worst value is 0. + + Args: + pred: estimated probabilities + target: ground-truth labels + beta: weights recall when combining the score. + beta < 1: more weight to precision. + beta > 1 more weight to recall + beta = 0: only precision + beta -> inf: only recall + num_classes: number of classes + reduction: method for reducing F-score (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements. + + Return: + Tensor with the value of F-score. It is a value between 0-1. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> fbeta_score(x, y, 0.2) + tensor(0.7407) + """ + prec, rec = precision_recall(pred=pred, target=target, + num_classes=num_classes, + reduction='none') + + nom = (1 + beta ** 2) * prec * rec + denom = ((beta ** 2) * prec + rec) + fbeta = nom / denom + + # drop NaN after zero division + fbeta[fbeta != fbeta] = 0 + + return reduce(fbeta, reduction=reduction) + + +def f1_score( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + reduction='elementwise_mean', +) -> torch.Tensor: + """ + Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall. + It ranges between 1 and 0, where 1 is perfect and the worst value is 0. + + Args: + pred: estimated probabilities + target: ground-truth labels + num_classes: number of classes + reduction: method for reducing F1-score (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements. + + Return: + Tensor containing F1-score + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> f1_score(x, y) + tensor(0.6667) + """ + return fbeta_score(pred=pred, target=target, beta=1., + num_classes=num_classes, reduction=reduction) + + +def _binary_clf_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py + """ + if sample_weight is not None and not isinstance(sample_weight, torch.Tensor): + sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float) + + # remove class dimension if necessary + if pred.ndim > target.ndim: + pred = pred[:, 0] + desc_score_indices = torch.argsort(pred, descending=True) + + pred = pred[desc_score_indices] + target = target[desc_score_indices] + + if sample_weight is not None: + weight = sample_weight[desc_score_indices] + else: + weight = 1. + + # pred typically has many tied values. Here we extract + # the indices associated with the distinct values. We also + # concatenate a value for the end of the curve. 
+ distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0] + threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1) + + target = (target == pos_label).to(torch.long) + tps = torch.cumsum(target * weight, dim=0)[threshold_idxs] + + if sample_weight is not None: + # express fps as a cumsum to ensure fps is increasing even in + # the presence of floating point errors + fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs] + else: + fps = 1 + threshold_idxs - tps + + return fps, tps, pred[threshold_idxs] + + +def roc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary. + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class + + Return: + false-positive rate (fpr), true-positive rate (tpr), thresholds + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> fpr, tpr, thresholds = roc(x, y) + >>> fpr + tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]) + >>> tpr + tensor([0., 0., 0., 1., 1.]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + + """ + fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, + sample_weight=sample_weight, + pos_label=pos_label) + + # Add an extra threshold position + # to make sure that the curve starts at (0, 0) + tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps]) + fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps]) + thresholds = torch.cat([thresholds[0][None] + 1, thresholds]) + + if fps[-1] <= 0: + raise ValueError("No negative samples in targets, false positive value should be meaningless") + + fpr = fps / fps[-1] + + if tps[-1] <= 0: + raise ValueError("No positive samples in targets, true positive value should be meaningless") + + tpr = tps / tps[-1] + + return fpr, tpr, thresholds + + +def multiclass_roc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, +) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: + """ + Computes the Receiver Operating Characteristic (ROC) for multiclass predictors. + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + num_classes: number of classes (default: None, computes automatically from data) + + Return: + returns roc for each class. + Number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... 
[0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) + """ + num_classes = get_num_classes(pred, target, num_classes) + + class_roc_vals = [] + for c in range(num_classes): + pred_c = pred[:, c] + + class_roc_vals.append(roc(pred=pred_c, target=target, + sample_weight=sample_weight, pos_label=c)) + + return tuple(class_roc_vals) + + +def precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes precision-recall pairs for different thresholds. + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class + + Return: + precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target) + >>> precision + tensor([0.3333, 0.0000, 0.0000, 1.0000]) + >>> recall + tensor([1., 0., 0., 0.]) + >>> thresholds + tensor([1, 2, 3]) + + """ + fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, + sample_weight=sample_weight, + pos_label=pos_label) + + precision = tps / (tps + fps) + recall = tps / tps[-1] + + # stop when full recall attained + # and reverse the outputs so recall is decreasing + last_ind = torch.where(tps == tps[-1])[0][0] + sl = slice(0, last_ind.item() + 1) + + # need to call reversed explicitly, since including that to slice would + # introduce negative strides that are not yet supported in pytorch + precision = torch.cat([reversed(precision[sl]), + torch.ones(1, dtype=precision.dtype, + device=precision.device)]) + + recall = torch.cat([reversed(recall[sl]), + torch.zeros(1, dtype=recall.dtype, + device=recall.device)]) + + thresholds = torch.tensor(reversed(thresholds[sl])) + + return precision, recall, thresholds + + +def multiclass_precision_recall_curve( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + num_classes: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes precision-recall pairs for different thresholds given a multiclass scores. + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weight + num_classes: number of classes + + Return: + number of classes, precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... 
[0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> nb_classes, precision, recall, thresholds = multiclass_precision_recall_curve(pred, target) + >>> nb_classes + (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) + >>> precision + (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) + >>> recall + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) + """ + num_classes = get_num_classes(pred, target, num_classes) + + class_pr_vals = [] + for c in range(num_classes): + pred_c = pred[:, c] + + class_pr_vals.append(precision_recall_curve( + pred=pred_c, + target=target, + sample_weight=sample_weight, pos_label=c)) + + return tuple(class_pr_vals) + + +def auc( + x: torch.Tensor, + y: torch.Tensor, + reorder: bool = True +) -> torch.Tensor: + """ + Computes Area Under the Curve (AUC) using the trapezoidal rule + + Args: + x: x-coordinates + y: y-coordinates + reorder: reorder coordinates, so they are increasing + + Return: + Tensor containing AUC score (float) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> auc(x, y) + tensor(4.) + """ + direction = 1. + + if reorder: + # can't use lexsort here since it is not implemented for torch + order = torch.argsort(x) + x, y = x[order], y[order] + else: + dx = x[1:] - x[:-1] + if (dx < 0).any(): + if (dx, 0).all(): + direction = -1. + else: + raise ValueError("Reordering is not turned on, and " + "the x array is not increasing: %s" % x) + + return direction * torch.trapz(y, x) + + +def auc_decorator(reorder: bool = True) -> Callable: + def wrapper(func_to_decorate: Callable) -> Callable: + @wraps(func_to_decorate) + def new_func(*args, **kwargs) -> torch.Tensor: + x, y = func_to_decorate(*args, **kwargs)[:2] + + return auc(x, y, reorder=reorder) + + return new_func + + return wrapper + + +def multiclass_auc_decorator(reorder: bool = True) -> Callable: + def wrapper(func_to_decorate: Callable) -> Callable: + def new_func(*args, **kwargs) -> torch.Tensor: + results = [] + for class_result in func_to_decorate(*args, **kwargs): + x, y = class_result[:2] + results.append(auc(x, y, reorder=reorder)) + + return torch.cat(results) + + return new_func + + return wrapper + + +def auroc( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> torch.Tensor: + """ + Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class + + Return: + Tensor containing ROCAUC score + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> auroc(x, y) + tensor(0.3333) + """ + + @auc_decorator(reorder=True) + def _auroc(pred, target, sample_weight, pos_label): + return roc(pred, target, sample_weight, pos_label) + + return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) + + +def average_precision( + pred: torch.Tensor, + target: torch.Tensor, + sample_weight: Optional[Sequence] = None, + pos_label: int = 1., +) -> torch.Tensor: + """ + Compute average precision from prediction scores + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive 
class + + Return: + Tensor containing average precision score + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> average_precision(x, y) + tensor(0.3333) + """ + precision, recall, _ = precision_recall_curve(pred=pred, target=target, + sample_weight=sample_weight, + pos_label=pos_label) + # Return the step function integral + # The following works because the last entry of precision is + # guaranteed to be 1, as returned by precision_recall_curve + return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1]) + + +def dice_score( + pred: torch.Tensor, + target: torch.Tensor, + bg: bool = False, + nan_score: float = 0.0, + no_fg_score: float = 0.0, + reduction: str = 'elementwise_mean', +) -> torch.Tensor: + """ + Compute dice score from prediction scores + + Args: + pred: estimated probabilities + target: ground-truth labels + bg: whether to also compute dice for the background + nan_score: score to return, if a NaN occurs during computation + no_fg_score: score to return, if no foreground pixel was found in target + reduction: a method for reducing accuracies over labels (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor containing dice score + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> dice_score(pred, target) + tensor(0.3333) + + """ + num_classes = pred.shape[1] + bg = (1 - int(bool(bg))) + scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32) + for i in range(bg, num_classes): + if not (target == i).any(): + # no foreground class + scores[i - bg] += no_fg_score + continue + + tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i) + denom = (2 * tp + fp + fn).to(torch.float) + # nan result + score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score + + scores[i - bg] += score_cls + return reduce(scores, reduction=reduction) + + +def iou( + pred: torch.Tensor, + target: torch.Tensor, + num_classes: Optional[int] = None, + remove_bg: bool = False, + reduction: str = 'elementwise_mean' +) -> torch.Tensor: + """ + Intersection over union, or Jaccard index calculation. + + Args: + pred: Tensor containing predictions + target: Tensor containing targets + num_classes: Optionally specify the number of classes + remove_bg: Flag to state whether a background class has been included + within input parameters. If true, will remove background class. 
If + false, return IoU over all classes + Assumes that background is '0' class in input tensor + reduction: a method for reducing IoU over labels (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + IoU score : Tensor containing single value if reduction is + 'elementwise_mean', or number of classes if reduction is 'none' + + Example: + + >>> target = torch.randint(0, 1, (10, 25, 25)) + >>> pred = torch.tensor(target) + >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + >>> iou(pred, target) + tensor(0.4914) + + """ + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes) + if remove_bg: + tps = tps[1:] + fps = fps[1:] + fns = fns[1:] + denom = fps + fns + tps + denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom) + iou = tps / denom + return reduce(iou, reduction=reduction) diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py new file mode 100644 index 0000000000..22645bb549 --- /dev/null +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -0,0 +1,92 @@ +# referenced from +# Library Name: torchtext +# Authors: torchtext authors and @sluks +# Date: 2020-07-18 +# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score +from collections import Counter +from typing import Sequence, List + +import torch + + +def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: + """Counting how many times each word appears in a given text with ngram + + Args: + ngram_input_list: A list of translated text or reference texts + n_gram: gram value ranged 1 to 4 + + Return: + ngram_counter: a collections.Counter object of ngram + """ + + ngram_counter = Counter() + + for i in range(1, n_gram + 1): + for j in range(len(ngram_input_list) - i + 1): + ngram_key = tuple(ngram_input_list[j : i + j]) + ngram_counter[ngram_key] += 1 + + return ngram_counter + + +def bleu_score( + translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False +) -> torch.Tensor: + """Calculate BLEU score of machine translated text with one or more references. + + Args: + translate_corpus: An iterable of machine translated corpus + reference_corpus: An iterable of iterables of reference corpus + n_gram: Gram value ranged from 1 to 4 (Default 4) + smooth: Whether or not to apply smoothing – Lin et al. 
2004 + + Return: + A Tensor with BLEU Score + + Example: + + >>> translate_corpus = ['the cat is on the mat'.split()] + >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + >>> bleu_score(translate_corpus, reference_corpus) + tensor(0.7598) + """ + + assert len(translate_corpus) == len(reference_corpus) + numerator = torch.zeros(n_gram) + denominator = torch.zeros(n_gram) + precision_scores = torch.zeros(n_gram) + c = 0.0 + r = 0.0 + for (translation, references) in zip(translate_corpus, reference_corpus): + c += len(translation) + ref_len_list = [len(ref) for ref in references] + ref_len_diff = [abs(len(translation) - x) for x in ref_len_list] + r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] + translation_counter = _count_ngram(translation, n_gram) + reference_counter = Counter() + for ref in references: + reference_counter |= _count_ngram(ref, n_gram) + + ngram_counter_clip = translation_counter & reference_counter + for counter_clip in ngram_counter_clip: + numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] + + for counter in translation_counter: + denominator[len(counter) - 1] += translation_counter[counter] + + trans_len = torch.tensor(c) + ref_len = torch.tensor(r) + if min(numerator) == 0.0: + return torch.tensor(0.0) + + if smooth: + precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) + else: + precision_scores = numerator / denominator + log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) + geometric_mean = torch.exp(torch.sum(log_precision_scores)) + brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) + bleu = brevity_penalty * geometric_mean + + return bleu diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py new file mode 100644 index 0000000000..b9be8ca7da --- /dev/null +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -0,0 +1,24 @@ +import torch + + +def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: + """ + Reduces a given tensor by a given reduction method + + Args: + to_reduce : the tensor, which shall be reduced + reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum') + + Return: + reduced Tensor + + Raise: + ValueError if an invalid reduction parameter was given + """ + if reduction == 'elementwise_mean': + return torch.mean(to_reduce) + if reduction == 'none': + return to_reduce + if reduction == 'sum': + return torch.sum(to_reduce) + raise ValueError('Reduction parameter unknown.') diff --git a/pytorch_lightning/metrics/functional/regression.py b/pytorch_lightning/metrics/functional/regression.py new file mode 100644 index 0000000000..6ad5ee6cfb --- /dev/null +++ b/pytorch_lightning/metrics/functional/regression.py @@ -0,0 +1,297 @@ +from typing import Sequence + +import torch +from torch.nn import functional as F + +from pytorch_lightning.metrics.functional.reduction import reduce + + +def mse( + pred: torch.Tensor, + target: torch.Tensor, + reduction: str = 'elementwise_mean' +) -> torch.Tensor: + """ + Computes mean squared error + + Args: + pred: estimated labels + target: ground truth labels + reduction: method for reducing mse (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor with MSE + + Example: + + >>> x = torch.tensor([0., 1, 2, 3]) + 
>>> y = torch.tensor([0., 1, 2, 2]) + >>> mse(x, y) + tensor(0.2500) + + """ + mse = F.mse_loss(pred, target, reduction='none') + mse = reduce(mse, reduction=reduction) + return mse + + +def rmse( + pred: torch.Tensor, + target: torch.Tensor, + reduction: str = 'elementwise_mean' +) -> torch.Tensor: + """ + Computes root mean squared error + + Args: + pred: estimated labels + target: ground truth labels + reduction: method for reducing rmse (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor with RMSE + + + >>> x = torch.tensor([0., 1, 2, 3]) + >>> y = torch.tensor([0., 1, 2, 2]) + >>> rmse(x, y) + tensor(0.5000) + + """ + rmse = torch.sqrt(mse(pred, target, reduction=reduction)) + return rmse + + +def mae( + pred: torch.Tensor, + target: torch.Tensor, + reduction: str = 'elementwise_mean' +) -> torch.Tensor: + """ + Computes mean absolute error + + Args: + pred: estimated labels + target: ground truth labels + reduction: method for reducing mae (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor with MAE + + Example: + + >>> x = torch.tensor([0., 1, 2, 3]) + >>> y = torch.tensor([0., 1, 2, 2]) + >>> mae(x, y) + tensor(0.2500) + + """ + mae = F.l1_loss(pred, target, reduction='none') + mae = reduce(mae, reduction=reduction) + return mae + + +def rmsle( + pred: torch.Tensor, + target: torch.Tensor, + reduction: str = 'elementwise_mean' +) -> torch.Tensor: + """ + Computes root mean squared log error + + Args: + pred: estimated labels + target: ground truth labels + reduction: method for reducing rmsle (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements + + Return: + Tensor with RMSLE + + Example: + + >>> x = torch.tensor([0., 1, 2, 3]) + >>> y = torch.tensor([0., 1, 2, 2]) + >>> rmsle(x, y) + tensor(0.0207) + + """ + rmsle = mse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction) + return rmsle + + +def psnr( + pred: torch.Tensor, + target: torch.Tensor, + data_range: float = None, + base: float = 10.0, + reduction: str = 'elementwise_mean' +) -> torch.Tensor: + """ + Computes the peak signal-to-noise ratio + + Args: + pred: estimated signal + target: groun truth signal + data_range: the range of the data. 
If None, it is determined from the data (max - min) + base: a base of a logarithm to use (default: 10) + reduction: method for reducing psnr (default: takes the mean) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum add elements + + Return: + Tensor with PSNR score + + Example: + + >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]]) + >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]]) + >>> psnr(pred, target) + tensor(2.5527) + + """ + + if data_range is None: + data_range = max(target.max() - target.min(), pred.max() - pred.min()) + else: + data_range = torch.tensor(float(data_range)) + + mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction) + psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score) + psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) + return psnr + + +def _gaussian_kernel(channel, kernel_size, sigma, device): + def gaussian(kernel_size, sigma, device): + gauss = torch.arange( + start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device + ) + gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2))) + return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) + + gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device) + gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + + return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) + + +def ssim( + pred: torch.Tensor, + target: torch.Tensor, + kernel_size: Sequence[int] = (11, 11), + sigma: Sequence[float] = (1.5, 1.5), + reduction: str = "elementwise_mean", + data_range: float = None, + k1: float = 0.01, + k2: float = 0.03 +) -> torch.Tensor: + """ + Computes Structual Similarity Index Measure + + Args: + pred: Estimated image + target: Ground truth image + kernel_size: Size of the gaussian kernel. Default: (11, 11) + sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5) + reduction: A method for reducing ssim over all elements in the ``pred`` tensor. Default: ``elementwise_mean`` + + Available reduction methods: + - elementwise_mean: takes the mean + - none: pass away + - sum: add elements + + data_range: Range of the image. If ``None``, it is determined from the image (max - min) + k1: Parameter of SSIM. Default: 0.01 + k2: Parameter of SSIM. Default: 0.03 + + Returns: + A Tensor with SSIM + + Example: + + >>> pred = torch.rand([16, 1, 16, 16]) + >>> target = pred * 1.25 + >>> ssim(pred, target) + tensor(0.9520) + """ + + if pred.dtype != target.dtype: + raise TypeError( + "Expected `pred` and `target` to have the same data type." + f" Got pred: {pred.dtype} and target: {target.dtype}." + ) + + if pred.shape != target.shape: + raise ValueError( + "Expected `pred` and `target` to have the same shape." + f" Got pred: {pred.shape} and target: {target.shape}." + ) + + if len(pred.shape) != 4 or len(target.shape) != 4: + raise ValueError( + "Expected `pred` and `target` to have BxCxHxW shape." + f" Got pred: {pred.shape} and target: {target.shape}." + ) + + if len(kernel_size) != 2 or len(sigma) != 2: + raise ValueError( + "Expected `kernel_size` and `sigma` to have the length of two." + f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}." + ) + + if any(x % 2 == 0 or x <= 0 for x in kernel_size): + raise ValueError(f"Expected `kernel_size` to have odd positive number. 
Got {kernel_size}.") + + if any(y <= 0 for y in sigma): + raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.") + + if data_range is None: + data_range = max(pred.max() - pred.min(), target.max() - target.min()) + + C1 = pow(k1 * data_range, 2) + C2 = pow(k2 * data_range, 2) + device = pred.device + + channel = pred.size(1) + kernel = _gaussian_kernel(channel, kernel_size, sigma, device) + mu_pred = F.conv2d(pred, kernel, groups=channel) + mu_target = F.conv2d(target, kernel, groups=channel) + + mu_pred_sq = mu_pred.pow(2) + mu_target_sq = mu_target.pow(2) + mu_pred_target = mu_pred * mu_target + + sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq + sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq + sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target + + UPPER = 2 * sigma_pred_target + C2 + LOWER = sigma_pred_sq + sigma_target_sq + C2 + + ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER) + + return reduce(ssim_idx, reduction) diff --git a/tests/metrics/functional/__init__.py b/tests/metrics/functional/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py new file mode 100644 index 0000000000..c9e1f0892f --- /dev/null +++ b/tests/metrics/functional/test_classification.py @@ -0,0 +1,341 @@ +from functools import partial + +import pytest +import torch +from sklearn.metrics import ( + accuracy_score as sk_accuracy, + precision_score as sk_precision, + recall_score as sk_recall, + f1_score as sk_f1_score, + fbeta_score as sk_fbeta_score, + confusion_matrix as sk_confusion_matrix, +) + +from pytorch_lightning import seed_everything +from pytorch_lightning.metrics.functional.classification import ( + to_onehot, + to_categorical, + get_num_classes, + stat_scores, + stat_scores_multiple_classes, + accuracy, + confusion_matrix, + precision, + recall, + fbeta_score, + f1_score, + _binary_clf_curve, + dice_score, + average_precision, + auroc, + precision_recall_curve, + roc, + auc, + iou, +) + + +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + pytest.param(sk_accuracy, accuracy, id='accuracy'), + pytest.param(partial(sk_precision, average='macro'), precision, id='precision'), + pytest.param(partial(sk_recall, average='macro'), recall, id='recall'), + pytest.param(partial(sk_f1_score, average='macro'), f1_score, id='f1_score'), + pytest.param(partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2), id='fbeta_score'), + pytest.param(sk_confusion_matrix, confusion_matrix, id='confusion_matrix') +]) +def test_against_sklearn(sklearn_metric, torch_metric): + """Compare PL metrics to sklearn version.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # iterate over different label counts in predictions and target + for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]: + pred = torch.randint(n_cls_pred, (300,), device=device) + target = torch.randint(n_cls_target, (300,), device=device) + + sk_score = sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()) + sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) + pl_score = torch_metric(pred, target) + assert torch.allclose(sk_score, pl_score) + + +def test_onehot(): + test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + expected = torch.stack([ + torch.cat([torch.eye(5, 
dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) + ]) + + assert test_tensor.shape == (2, 5) + assert expected.shape == (2, 10, 5) + + onehot_classes = to_onehot(test_tensor, num_classes=10) + onehot_no_classes = to_onehot(test_tensor) + + assert torch.allclose(onehot_classes, onehot_no_classes) + + assert onehot_classes.shape == expected.shape + assert onehot_no_classes.shape == expected.shape + + assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes) + assert torch.allclose(expected.to(onehot_classes), onehot_classes) + + +def test_to_categorical(): + test_tensor = torch.stack([ + torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]), + torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)]) + ]).to(torch.float) + + expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) + assert expected.shape == (2, 5) + assert test_tensor.shape == (2, 10, 5) + + result = to_categorical(test_tensor) + + assert result.shape == expected.shape + assert torch.allclose(result, expected.to(result.dtype)) + + +@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [ + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10), + pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), + pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10), +]) +def test_get_num_classes(pred, target, num_classes, expected_num_classes): + assert get_num_classes(pred, target, num_classes) == expected_num_classes + + +@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', + 'expected_tn', 'expected_fn', 'expected_support'], [ + pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2) +]) +def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): + tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4) + + assert tp.item() == expected_tp + assert fp.item() == expected_fp + assert tn.item() == expected_tn + assert fn.item() == expected_fn + assert sup.item() == expected_support + + +@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', + 'expected_tn', 'expected_fn', 'expected_support'], [ + pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), + [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), + [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]) +]) +def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): + tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target) + + assert torch.allclose(torch.tensor(expected_tp).to(tp), tp) + assert torch.allclose(torch.tensor(expected_fp).to(fp), fp) + assert torch.allclose(torch.tensor(expected_tn).to(tn), tn) + assert torch.allclose(torch.tensor(expected_fn).to(fn), fn) + assert torch.allclose(torch.tensor(expected_support).to(sup), sup) + + +def test_multilabel_accuracy(): + # Dense label indicator matrix format + y1 = torch.tensor([[0, 1, 1], [1, 0, 1]]) + y2 = torch.tensor([[0, 0, 1], [1, 0, 1]]) + + assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.])) + 
assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.])) + assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.])) + assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.])) + assert torch.allclose(accuracy(y1, torch.logical_not(y1), reduction='none'), torch.tensor([0., 0.])) + + with pytest.raises(RuntimeError): + accuracy(y2, torch.zeros_like(y2), reduction='none') + + +def test_accuracy(): + pred = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 2, 2]) + acc = accuracy(pred, target) + + assert acc.item() == 0.75 + + pred = torch.tensor([0, 1, 2, 2]) + target = torch.tensor([0, 1, 1, 3]) + acc = accuracy(pred, target) + + assert acc.item() == 0.50 + + +def test_confusion_matrix(): + target = (torch.arange(120) % 3).view(-1, 1) + pred = target.clone() + cm = confusion_matrix(pred, target, normalize=True) + + assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])) + + pred = torch.zeros_like(pred) + cm = confusion_matrix(pred, target, normalize=True) + assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]])) + + +@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ + pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), + pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]) +]) +def test_precision_recall(pred, target, expected_prec, expected_rec): + prec = precision(pred, target, reduction='none') + rec = recall(pred, target, reduction='none') + + assert torch.allclose(torch.tensor(expected_prec).to(prec), prec) + assert torch.allclose(torch.tensor(expected_rec).to(rec), rec) + + +@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [ + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]), + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]), + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]), +]) +def test_fbeta_score(pred, target, beta, exp_score): + score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, reduction='none') + assert torch.allclose(score, torch.tensor(exp_score)) + + score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, reduction='none') + assert torch.allclose(score, torch.tensor(exp_score)) + + +@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [ + pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]), + pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]), + pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]), +]) +def test_f1_score(pred, target, exp_score): + score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none') + assert torch.allclose(score, torch.tensor(exp_score)) + + score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), reduction='none') + assert torch.allclose(score, torch.tensor(exp_score)) + + +@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [ + pytest.param(1, 1., 42), + pytest.param(None, 1., 42), +]) +def test_binary_clf_curve(sample_weight, pos_label, exp_shape): + # TODO: move back the pred and target to test func arguments + # if you fix the array inside the function, you'd also have fix the shape, + # because when the array changes, you also have to fix the shape + seed_everything(0) + pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100 + target = 
torch.tensor([0, 1] * 50, dtype=torch.int) + if sample_weight is not None: + sample_weight = torch.ones_like(pred) * sample_weight + + fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label) + + assert isinstance(tps, torch.Tensor) + assert isinstance(fps, torch.Tensor) + assert isinstance(thresh, torch.Tensor) + assert tps.shape == (exp_shape,) + assert fps.shape == (exp_shape,) + assert thresh.shape == (exp_shape,) + + +@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [ + pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4]) +]) +def test_pr_curve(pred, target, expected_p, expected_r, expected_t): + p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target)) + assert p.size() == r.size() + assert p.size(0) == t.size(0) + 1 + + assert torch.allclose(p, torch.tensor(expected_p).to(p)) + assert torch.allclose(r, torch.tensor(expected_r).to(r)) + assert torch.allclose(t, torch.tensor(expected_t).to(t)) + + +@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [ + pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), + pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), + pytest.param([1, 1], [1, 0], [0, 1], [0, 1]), + pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), + pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]), +]) +def test_roc_curve(pred, target, expected_tpr, expected_fpr): + fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target)) + + assert fpr.shape == tpr.shape + assert fpr.size(0) == thresh.size(0) + assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr)) + assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr)) + + +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.), + pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.), + pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5), + pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.), + pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5), +]) +def test_auroc(pred, target, expected): + score = auroc(torch.tensor(pred), torch.tensor(target)).item() + assert score == expected + + +@pytest.mark.parametrize(['x', 'y', 'expected'], [ + pytest.param([0, 1], [0, 1], 0.5), + pytest.param([1, 0], [0, 1], 0.5), + pytest.param([1, 0, 0], [0, 1, 1], 0.5), + pytest.param([0, 1], [1, 1], 1), + pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5), +]) +def test_auc(x, y, expected): + # Test Area Under Curve (AUC) computation + assert auc(torch.tensor(x), torch.tensor(y)) == expected + + +@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [ + # Check the average_precision_score of a constant predictor is + # the TPR + # Generate a dataset with 25% of positives + # And a constant score + # The precision is then the fraction of positive whatever the recall + # is, as there is only one threshold: + pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25), + # With threshold 0.8 : 1 TP and 2 TN and one FN + pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75), +]) +def test_average_precision(scores, target, expected_score): + assert average_precision(scores, target) == expected_score + + +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.), + pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.), + pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3), + pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.), +]) +def test_dice_score(pred, 
target, expected): + score = dice_score(torch.tensor(pred), torch.tensor(target)) + assert score == expected + + +@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [ + pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])), + pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])), + pytest.param(False, 'none', True, torch.Tensor([1, 1])), + pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])), + pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])), + pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])), +]) +def test_iou(half_ones, reduction, remove_bg, expected): + pred = (torch.arange(120) % 3).view(-1, 1) + target = (torch.arange(120) % 3).view(-1, 1) + if half_ones: + pred[:60] = 1 + iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction) + assert torch.allclose(iou_val, expected, atol=1e-9) + + +# example data taken from +# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py diff --git a/tests/metrics/functional/test_nlp.py b/tests/metrics/functional/test_nlp.py new file mode 100644 index 0000000000..2f1647270e --- /dev/null +++ b/tests/metrics/functional/test_nlp.py @@ -0,0 +1,66 @@ +import pytest +import torch +from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu + +from pytorch_lightning.metrics.functional.nlp import bleu_score + +# example taken from +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu +HYPOTHESIS1 = tuple( + "It is a guide to action which ensures that the military always obeys the commands of the party".split() +) +REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split()) +REFERENCE2 = tuple( + "It is a guiding principle which makes the military forces always being under the command of the Party".split() +) +REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split()) + + +# example taken from +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu +HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split() +HYP2 = "he read the book because he was interested in world history".split() + +REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split() +REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split() +REF1C = "It is the practical guide for the army always to heed the directions of the party".split() +REF2A = "he was interested in world history because he read the book".split() + +LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]] +HYPOTHESES = [HYP1, HYP2] + +# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction +smooth_func = SmoothingFunction().method2 + + +@pytest.mark.parametrize( + ["weights", "n_gram", "smooth_func", "smooth"], + [ + pytest.param([1], 1, None, False), + pytest.param([0.5, 0.5], 2, smooth_func, True), + pytest.param([0.333333, 0.333333, 0.333333], 3, None, False), + pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True), + ], +) +def test_bleu_score(weights, n_gram, smooth_func, smooth): + nltk_output = sentence_bleu( + [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func + ) + pl_output = 
bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth) + assert torch.allclose(pl_output, torch.tensor(nltk_output)) + + nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func) + pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth) + assert torch.allclose(pl_output, torch.tensor(nltk_output)) + + +def test_bleu_empty(): + hyp = [[]] + ref = [[[]]] + assert bleu_score(hyp, ref) == torch.tensor(0.0) + + +def test_no_4_gram(): + hyps = [["My", "full", "pytorch-lightning"]] + refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]] + assert bleu_score(hyps, refs) == torch.tensor(0.0) diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py new file mode 100644 index 0000000000..71d2b6f773 --- /dev/null +++ b/tests/metrics/functional/test_reduction.py @@ -0,0 +1,15 @@ +import pytest +import torch + +from pytorch_lightning.metrics.functional.reduction import reduce + + +def test_reduce(): + start_tensor = torch.rand(50, 40, 30) + + assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor)) + assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor)) + assert torch.allclose(reduce(start_tensor, 'none'), start_tensor) + + with pytest.raises(ValueError): + reduce(start_tensor, 'error_reduction') diff --git a/tests/metrics/functional/test_regression.py b/tests/metrics/functional/test_regression.py new file mode 100644 index 0000000000..6aae9027bf --- /dev/null +++ b/tests/metrics/functional/test_regression.py @@ -0,0 +1,144 @@ +import numpy as np +import pytest +import torch +from skimage.metrics import peak_signal_noise_ratio as ski_psnr +from skimage.metrics import structural_similarity as ski_ssim + +from pytorch_lightning.metrics.functional import ( + mae, + mse, + psnr, + rmse, + rmsle, + ssim +) + + +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25), + pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0), +]) +def test_mse(pred, target, expected): + score = mse(torch.tensor(pred), torch.tensor(target)) + assert score.item() == expected + + +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0), + pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.5), + pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321), +]) +def test_rmse(pred, target, expected): + score = rmse(torch.tensor(pred), torch.tensor(target)) + assert torch.allclose(score, torch.tensor(expected), atol=1e-3) + + +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0), + pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25), + pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5), +]) +def test_mae(pred, target, expected): + score = mae(torch.tensor(pred), torch.tensor(target)) + assert score.item() == expected + + +@pytest.mark.parametrize(['pred', 'target', 'expected'], [ + pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0), + pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.0207), + pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841), +]) +def test_rmsle(pred, target, expected): + score = rmsle(torch.tensor(pred), torch.tensor(target)) + assert torch.allclose(score, torch.tensor(expected), atol=1e-3) + + +@pytest.mark.parametrize(['pred', 'target'], [ + pytest.param([0., 1., 2., 3.], [0., 1., 2., 3.]), + pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]), + 
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
+])
+def test_psnr_with_skimage(pred, target):
+    score = psnr(pred=torch.tensor(pred),
+                 target=torch.tensor(target))
+    sk_score = ski_psnr(np.array(pred), np.array(target), data_range=3)  # inputs span [0, 3]
+    assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float), atol=1e-3)
+
+
+@pytest.mark.parametrize(['pred', 'target'], [
+    pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
+    pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
+])
+def test_psnr_base_e_wider_range(pred, target):
+    score = psnr(pred=torch.tensor(pred),
+                 target=torch.tensor(target),
+                 data_range=4,
+                 base=2.718281828459045)
+    sk_score = ski_psnr(np.array(pred), np.array(target), data_range=4) * np.log(10)  # convert dB (base 10) to base e
+    assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float32), atol=1e-3)
+
+
+@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
+    pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio')
+])
+def test_psnr_against_sklearn(sklearn_metric, torch_metric):
+    """Compare the PL metric to the scikit-image implementation."""
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+    for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
+        pred = torch.randint(n_cls_pred, (500,), device=device, dtype=torch.float)
+        target = torch.randint(n_cls_target, (500,), device=device, dtype=torch.float)
+
+        sk_score = sklearn_metric(target.cpu().detach().numpy(),
+                                  pred.cpu().detach().numpy(),
+                                  data_range=n_cls_target)
+        sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
+        pl_score = torch_metric(pred, target, data_range=n_cls_target)
+        assert torch.allclose(sk_score, pl_score)
+
+
+@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [
+    pytest.param(16, 1, 0.125, False),
+    pytest.param(32, 1, 0.25, False),
+    pytest.param(48, 3, 0.5, True),
+    pytest.param(64, 4, 0.75, True),
+    pytest.param(128, 5, 1, True)
+])
+def test_ssim(size, channel, plus, multichannel):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    pred = torch.rand(1, channel, size, size, device=device)
+    target = pred + plus
+    ssim_idx = ssim(pred, target)
+    np_pred = np.random.rand(size, size, channel)
+    if multichannel is False:
+        np_pred = np_pred[:, :, 0]
+    np_target = np.add(np_pred, plus)
+    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
+    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)
+
+    ssim_idx = ssim(pred, pred)
+    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
+
+
+@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
+    pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # shape
+    pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # len(shape)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]),  # len(kernel), len(sigma)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]),  # len(kernel), len(sigma)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]),  # len(kernel), len(sigma)
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]),  # invalid kernel input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]),  # invalid kernel input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]),  # invalid kernel input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]),  # invalid sigma input
+    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]),  # invalid sigma input
+])
+def 
test_ssim_invalid_inputs(pred, target, kernel, sigma):
+    # mixing float32 and float64 inputs is rejected with a TypeError
+    pred_t = torch.rand(pred)
+    target_t = torch.rand(target, dtype=torch.float64)
+    with pytest.raises(TypeError):
+        ssim(pred_t, target_t)
+
+    # mismatched shapes or invalid kernel/sigma values are rejected with a ValueError
+    pred = torch.rand(pred)
+    target = torch.rand(target)
+    with pytest.raises(ValueError):
+        ssim(pred, target, kernel, sigma)
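+
+
+def test_ssim_identical_inputs():
+    # Illustrative sketch for contrast with the invalid cases above: identical,
+    # well-formed inputs score a perfect 1.0 (as also asserted in `test_ssim`).
+    # The `kernel_size`/`sigma` keyword names are assumed to correspond to the
+    # positional `kernel`/`sigma` values exercised by `test_ssim_invalid_inputs`.
+    pred = torch.rand(1, 1, 16, 16)
+    score = ssim(pred, pred, kernel_size=(11, 11), sigma=(1.5, 1.5))
+    assert torch.allclose(score, torch.tensor(1.0))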