691 lines
22 KiB
Python
691 lines
22 KiB
Python
|
from collections.abc import Sequence
from functools import wraps
from typing import Callable, Optional, Tuple

import torch

from pytorch_lightning.metrics.functional.reduction import reduce
|
||
|
|
||
|
|
||
|
def to_onehot(
        tensor: torch.Tensor,
        n_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    One-hot encodes a tensor of dense class labels.

    Args:
        tensor: dense label tensor, with shape [N, d1, d2, ...]
        n_classes: number of classes C; when omitted it is taken to be the
            largest value found in ``tensor`` plus one

    Output:
        A one-hot label tensor with shape [N, C, d1, d2, ...] that has the
        same dtype and device as ``tensor``
    """
    if n_classes is None:
        n_classes = int(tensor.max().detach().item() + 1)

    dims = tensor.shape
    encoded = torch.zeros(dims[0], n_classes, *dims[1:],
                          dtype=tensor.dtype, device=tensor.device)

    # add a class axis and broadcast the labels over it so that they can be
    # used as scatter indices along dim 1
    class_idx = tensor.long().unsqueeze(1).expand_as(encoded)
    encoded.scatter_(1, class_idx, 1.0)
    return encoded
|
||
|
|
||
|
|
||
|
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
    """
    Collapses a tensor of per-class scores into dense labels by taking the
    index of the largest entry along one dimension.

    Args:
        tensor: probabilities to get the categorical label [N, d1, d2, ...]
        argmax_dim: dimension the argmax is taken over (default: 1)

    Return:
        A tensor with categorical labels [N, d2, ...]
    """
    labels = tensor.argmax(dim=argmax_dim)
    return labels
|
||
|
|
||
|
|
||
|
def get_num_classes(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int],
) -> int:
    """
    Returns the number of classes for a given prediction and target tensor.

    Args:
        pred: predicted values
        target: true labels
        num_classes: number of classes if known (default: None)

    Return:
        An integer that represents the number of classes. An explicit
        ``num_classes`` is returned unchanged. Otherwise, if ``pred`` has
        more dimensions than ``target`` the size of its dim 1 is used; if
        not, the result is the largest value found in either tensor plus one.
    """
    if num_classes is not None:
        return num_classes

    if pred.ndim > target.ndim:
        # pred carries an explicit class dimension
        return pred.size(1)

    # Inspect both tensors: a class that is predicted but never appears in
    # the target (or the reverse) must still be counted.
    num_target_classes = int(target.max().detach().item() + 1)
    num_pred_classes = int(pred.max().detach().item() + 1)
    return max(num_target_classes, num_pred_classes)
|
||
|
|
||
|
|
||
|
def stat_scores(
        pred: torch.Tensor,
        target: torch.Tensor,
        class_index: int, argmax_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Counts the true positives, false positives, true negatives and false
    negatives of a single class.

    Args:
        pred: prediction tensor
        target: target tensor
        class_index: class to calculate over
        argmax_dim: if pred is a tensor of probabilities, this indicates the
            axis the argmax transformation will be applied over

    Return:
        Tensors in the following order: True Positive, False Positive, True Negative, False Negative
    """
    # pred with exactly one more dimension than target is reduced to labels
    if pred.ndim == target.ndim + 1:
        pred = to_categorical(pred, argmax_dim=argmax_dim)

    pred_hit = pred == class_index
    pred_miss = pred != class_index
    target_hit = target == class_index
    target_miss = target != class_index

    true_pos = (pred_hit * target_hit).to(torch.long).sum()
    false_pos = (pred_hit * target_miss).to(torch.long).sum()
    true_neg = (pred_miss * target_miss).to(torch.long).sum()
    false_neg = (pred_miss * target_hit).to(torch.long).sum()

    return true_pos, false_pos, true_neg, false_neg
|
||
|
|
||
|
|
||
|
def stat_scores_multiple_classes(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        argmax_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Runs stat_scores once per class and gathers the number of true
    positives, false positives, true negatives and false negatives of every
    class into four tensors.

    Args:
        pred: prediction tensor
        target: target tensor
        num_classes: number of classes; determined with get_num_classes
            when not given
        argmax_dim: if pred is a tensor of probabilities, this indicates the
            axis the argmax transformation will be applied over

    Return:
        Returns tensors for: tp, fp, tn, fn (each holding one entry per class)
    """
    num_classes = get_num_classes(pred=pred, target=target,
                                  num_classes=num_classes)

    # reduce class scores to labels once, up front, rather than per class
    if pred.ndim == target.ndim + 1:
        pred = to_categorical(pred, argmax_dim=argmax_dim)

    tps, fps, tns, fns = (torch.zeros((num_classes,), device=pred.device)
                          for _ in range(4))

    for cls in range(num_classes):
        tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=cls)
        tps[cls], fps[cls], tns[cls], fns[cls] = tp, fp, tn, fn

    return tps, fps, tns, fns
|
||
|
|
||
|
|
||
|
def accuracy(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the accuracy classification score

    Args:
        pred: predicted labels
        target: ground truth labels
        num_classes: number of classes
        reduction: a method for reducing accuracies over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        A Tensor with the classification score.

    Raises:
        RuntimeError: if ``num_classes`` is not given and ``target`` contains
            no value greater than zero
    """
    tps, fps, tns, fns = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    has_positive_label = (target > 0).any()
    if not has_positive_label and num_classes is None:
        raise RuntimeError("cannot infer num_classes when target is all zero")

    # per-class accuracy: correctly classified / all samples
    correct = tps + tns
    total = correct + fps + fns

    return reduce(correct / total, reduction=reduction)
|
||
|
|
||
|
|
||
|
def confusion_matrix(
        pred: torch.Tensor,
        target: torch.Tensor,
        normalize: bool = False,
) -> torch.Tensor:
    """
    Computes the confusion matrix C where each entry C_{i,j} is the number of observations
    in group i that were predicted in group j.

    Args:
        pred: estimated targets; either labels shaped like ``target`` or
            class scores with one extra dimension (reduced with to_categorical)
        target: ground truth labels
        normalize: normalizes confusion matrix so that every row (true
            class) sums to one; a class without any target sample gives a
            row of NaN

    Return:
        Tensor, confusion matrix C [num_classes, num_classes]
    """
    # class scores -> labels, mirroring what stat_scores does
    if pred.ndim == target.ndim + 1:
        pred = to_categorical(pred)

    pred_flat = pred.reshape(-1).long()
    target_flat = target.reshape(-1).long()

    # Also look at the predictions so that a class which is predicted but
    # never occurs in the target cannot produce an out-of-range bin.
    num_classes = max(get_num_classes(pred, target, None),
                      int(pred_flat.max().item()) + 1)

    # Encode every (target, pred) pair as one integer and count occurrences;
    # all tensors stay on the device of the inputs.
    pair_index = target_flat * num_classes + pred_flat
    bins = torch.bincount(pair_index, minlength=num_classes ** 2)
    cm = bins.reshape(num_classes, num_classes).float()

    if normalize:
        # keepdim so that row i is divided by the sum of row i
        cm = cm / cm.sum(-1, keepdim=True)

    return cm
|
||
|
|
||
|
|
||
|
def precision_recall(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes precision and recall from the per-class true positive, false
    positive and false negative counts.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision-recall values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with precision and recall. A class whose denominator is zero
        (never predicted for precision, never present for recall)
        contributes 0 instead of NaN.
    """
    tps, fps, _, fns = stat_scores_multiple_classes(pred=pred,
                                                    target=target,
                                                    num_classes=num_classes)

    tps = tps.to(torch.float)
    fps = fps.to(torch.float)
    fns = fns.to(torch.float)

    precision = tps / (tps + fps)
    recall = tps / (tps + fns)

    # 0 / 0 yields NaN, which would turn the reduced score into NaN as well;
    # treat an undefined per-class value as 0.
    precision = torch.where(torch.isnan(precision), torch.zeros_like(precision), precision)
    recall = torch.where(torch.isnan(recall), torch.zeros_like(recall), recall)

    precision = reduce(precision, reduction=reduction)
    recall = reduce(recall, reduction=reduction)
    return precision, recall
|
||
|
|
||
|
|
||
|
def precision(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes precision score.

    Thin wrapper that keeps the first element returned by precision_recall.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with precision.
    """
    prec, _ = precision_recall(pred=pred, target=target,
                               num_classes=num_classes, reduction=reduction)
    return prec
|
||
|
|
||
|
|
||
|
def recall(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes recall score.

    Thin wrapper that keeps the second element returned by precision_recall.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing recall values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with recall.
    """
    _, rec = precision_recall(pred=pred, target=target,
                              num_classes=num_classes, reduction=reduction)
    return rec
|
||
|
|
||
|
|
||
|
def fbeta_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        beta: float,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the F-beta score which is a weighted harmonic mean of precision and recall.
    It ranges between 1 and 0, where 1 is perfect and the worst value is 0.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        beta: weights recall when combining the score.
            beta < 1: more weight to precision.
            beta > 1: more weight to recall
            beta = 0: only precision
            beta -> inf: only recall
        num_classes: number of classes
        reduction: method for reducing F-score (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements.

    Return:
        Tensor with the value of F-score. It is a value between 0-1. A class
        whose score is undefined (zero denominator) contributes 0.
    """
    # per-class values are needed here; the reduction is applied at the end
    prec, rec = precision_recall(pred=pred, target=target,
                                 num_classes=num_classes,
                                 reduction='none')

    beta_sq = beta ** 2
    nom = (1 + beta_sq) * prec * rec
    denom = beta_sq * prec + rec
    fbeta = nom / denom

    # A zero denominator (or NaN precision/recall) gives NaN, which would
    # make the reduced score NaN as well; count such classes as 0.
    fbeta = torch.where(torch.isnan(fbeta), torch.zeros_like(fbeta), fbeta)

    return reduce(fbeta, reduction=reduction)
|
||
|
|
||
|
|
||
|
def f1_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes F1-score a.k.a F-measure, i.e. the F-beta score with beta = 1.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing F1-score (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements.

    Return:
        Tensor containing F1-score
    """
    score = fbeta_score(pred=pred, target=target, beta=1.,
                        num_classes=num_classes, reduction=reduction)
    return score
|
||
|
|
||
|
|
||
|
def _binary_clf_curve(
|
||
|
pred: torch.Tensor,
|
||
|
target: torch.Tensor,
|
||
|
sample_weight: Optional[Sequence] = None,
|
||
|
pos_label: int = 1.,
|
||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
|
"""
|
||
|
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
|
||
|
"""
|
||
|
if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
|
||
|
sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)
|
||
|
|
||
|
# remove class dimension if necessary
|
||
|
if pred.ndim > target.ndim:
|
||
|
pred = pred[:, 0]
|
||
|
desc_score_indices = torch.argsort(pred, descending=True)
|
||
|
|
||
|
pred = pred[desc_score_indices]
|
||
|
target = target[desc_score_indices]
|
||
|
|
||
|
if sample_weight is not None:
|
||
|
weight = sample_weight[desc_score_indices]
|
||
|
else:
|
||
|
weight = 1.
|
||
|
|
||
|
# pred typically has many tied values. Here we extract
|
||
|
# the indices associated with the distinct values. We also
|
||
|
# concatenate a value for the end of the curve.
|
||
|
distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
|
||
|
threshold_idxs = torch.cat([distinct_value_indices,
|
||
|
torch.tensor([target.size(0) - 1])])
|
||
|
|
||
|
target = (target == pos_label).to(torch.long)
|
||
|
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
|
||
|
|
||
|
if sample_weight is not None:
|
||
|
# express fps as a cumsum to ensure fps is increasing even in
|
||
|
# the presence of floating point errors
|
||
|
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
|
||
|
|
||
|
else:
|
||
|
fps = 1 + threshold_idxs - tps
|
||
|
|
||
|
return fps, tps, pred[threshold_idxs]
|
||
|
|
||
|
|
||
|
def roc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1)

    Return:
        [Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds

    Raises:
        ValueError: if the final false positive count or the final true
            positive count is not greater than zero
    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,
                                             pos_label=pos_label)

    # Prepend an extra point so that the curve starts at (0, 0); its
    # threshold is the first threshold plus one.
    tps_origin = torch.zeros(1, dtype=tps.dtype, device=tps.device)
    fps_origin = torch.zeros(1, dtype=fps.dtype, device=fps.device)
    tps = torch.cat([tps_origin, tps])
    fps = torch.cat([fps_origin, fps])
    thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

    total_fp = fps[-1]
    if total_fp <= 0:
        raise ValueError("No negative samples in targets, false positive value should be meaningless")

    total_tp = tps[-1]
    if total_tp <= 0:
        raise ValueError("No positive samples in targets, true positive value should be meaningless")

    fpr = fps / total_fp
    tpr = tps / total_tp

    return fpr, tpr, thresholds
|
||
|
|
||
|
|
||
|
def multiclass_roc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.

    Every class c is treated as its own binary problem: column ``pred[:, c]``
    is scored against ``target`` with c as the positive label.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        num_classes: number of classes (default: None, computes automatically from data)

    Return:
        A tuple with one entry per class; each entry is the
        (false-positive rate, true-positive rate, thresholds) triple
        returned by roc for that class.
    """
    num_classes = get_num_classes(pred, target, num_classes)

    return tuple(
        roc(pred=pred[:, cls], target=target,
            sample_weight=sample_weight, pos_label=cls)
        for cls in range(num_classes)
    )
|
||
|
|
||
|
|
||
|
def precision_recall_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes precision-recall pairs for different thresholds.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1.)

    Return:
        [Tensor, Tensor, Tensor]: precision, recall, thresholds.
        Recall is ordered decreasingly; precision ends with an appended 1
        and recall with an appended 0, so both hold one element more than
        thresholds.
    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,
                                             pos_label=pos_label)

    precision = tps / (tps + fps)
    recall = tps / tps[-1]

    # stop when full recall attained
    # and reverse the outputs so recall is decreasing
    last_ind = torch.where(tps == tps[-1])[0][0]
    sl = slice(0, last_ind.item() + 1)

    # flip(0) returns a reversed copy directly; this avoids wrapping an
    # existing tensor in torch.tensor(...), which triggers a copy-construct
    # warning
    precision = torch.cat([precision[sl].flip(0),
                           torch.ones(1, dtype=precision.dtype,
                                      device=precision.device)])

    recall = torch.cat([recall[sl].flip(0),
                        torch.zeros(1, dtype=recall.dtype,
                                    device=recall.device)])

    # detach() keeps the returned thresholds out of the autograd graph, as
    # the previous torch.tensor(...) copy did
    thresholds = thresholds[sl].flip(0).detach()

    return precision, recall, thresholds
|
||
|
|
||
|
|
||
|
def multiclass_precision_recall_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Computes precision-recall pairs for different thresholds given a multiclass scores.

    Every class c is treated as its own binary problem: column ``pred[:, c]``
    is scored against ``target`` with c as the positive label.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weight
        num_classes: number of classes

    Return:
        A tuple with one entry per class; each entry is the
        (precision, recall, thresholds) triple returned by
        precision_recall_curve for that class.
    """
    num_classes = get_num_classes(pred, target, num_classes)

    return tuple(
        precision_recall_curve(pred=pred[:, cls], target=target,
                               sample_weight=sample_weight, pos_label=cls)
        for cls in range(num_classes)
    )
|
||
|
|
||
|
|
||
|
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
    """
    Computes Area Under the Curve (AUC) using the trapezoidal rule

    Args:
        x: x-coordinates
        y: y-coordinates
        reorder: reorder coordinates, so they are increasing.

    Return:
        AUC score (float)

    Raises:
        ValueError: if ``reorder`` is False and ``x`` is neither
            non-decreasing nor non-increasing
    """
    direction = 1.

    if reorder:
        # can't use lexsort here since it is not implemented for torch
        order = torch.argsort(x)
        x, y = x[order], y[order]
    else:
        dx = x[1:] - x[:-1]
        if (dx < 0).any():
            if (dx <= 0).all():
                # x is monotonically decreasing: the integral comes out
                # negative, so flip its sign
                direction = -1.
            else:
                raise ValueError("Reordering is not turned on, and "
                                 "the x array is not increasing: %s" % x)

    return direction * torch.trapz(y, x)
|
||
|
|
||
|
|
||
|
def auc_decorator(reorder: bool = True) -> Callable:
    """
    Builds a decorator for functions that return curve coordinates.

    The decorated function returns the area under the curve described by the
    first two elements of the wrapped function's result.

    Args:
        reorder: forwarded to auc

    Return:
        The decorator to apply to the curve function.
    """

    def wrapper(func_to_decorate: Callable) -> Callable:

        @wraps(func_to_decorate)
        def new_func(*args, **kwargs) -> torch.Tensor:
            curve = func_to_decorate(*args, **kwargs)
            # only the first two outputs (x and y) take part in the integral
            x, y = curve[:2]
            return auc(x, y, reorder=reorder)

        return new_func

    return wrapper
|
||
|
|
||
|
|
||
|
def multiclass_auc_decorator(reorder: bool = True) -> Callable:
    """
    Builds a decorator for functions that return one curve per class.

    The decorated function computes the area under every per-class curve
    (using the first two elements of each class result) and returns the
    areas combined into a single tensor.

    Args:
        reorder: forwarded to auc

    Return:
        The decorator to apply to the multiclass curve function.
    """

    def wrapper(func_to_decorate: Callable) -> Callable:

        @wraps(func_to_decorate)
        def new_func(*args, **kwargs) -> torch.Tensor:
            results = []
            for class_result in func_to_decorate(*args, **kwargs):
                x, y = class_result[:2]
                # auc yields a 0-dim tensor for 1-D curves, and torch.cat
                # refuses 0-dim inputs; give every result one dimension
                # before concatenating
                results.append(auc(x, y, reorder=reorder).reshape(-1))

            return torch.cat(results)

        return new_func

    return wrapper
|
||
|
|
||
|
|
||
|
def auroc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> torch.Tensor:
    """
    Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1.)

    Return:
        Tensor with the area under the ROC curve, integrated over the
        false-positive rate with reordering enabled.
    """
    # roc returns (fpr, tpr, thresholds); the thresholds are not needed
    fpr, tpr = roc(pred, target, sample_weight, pos_label)[:2]
    return auc(fpr, tpr, reorder=True)
|
||
|
|
||
|
|
||
|
def average_precision(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> torch.Tensor:
    """
    Computes the average precision as the step-function integral of the
    precision-recall curve.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1.)

    Return:
        Tensor with the sum over thresholds of
        (recall decrease between consecutive points) * precision.
    """
    precision, recall, _ = precision_recall_curve(pred=pred, target=target,
                                                  sample_weight=sample_weight,
                                                  pos_label=pos_label)
    # Return the step function integral
    # The following works because the last entry of precision is
    # guaranteed to be 1, as returned by precision_recall_curve
    # recall is returned in decreasing order, so the differences are
    # non-positive and the sum is negated. The parentheses matter: the
    # recall difference, not recall[:-1] alone, is weighted by precision.
    recall_steps = recall[1:] - recall[:-1]
    return -torch.sum(recall_steps * precision[:-1])
|
||
|
|
||
|
|
||
|
def dice_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        bg: bool = False,
        nan_score: float = 0.0,
        no_fg_score: float = 0.0,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes a per-class dice score, 2 * tp / (2 * tp + fp + fn), and reduces
    it over the classes.

    Args:
        pred: prediction tensor whose dim 1 has one entry per class
        target: tensor of target labels
        bg: whether class 0 is scored as well; when False scoring starts at
            class 1
        nan_score: score used for a class whose denominator is zero
        no_fg_score: score used for a class that does not occur in target
        reduction: reduction method handed to reduce
            (default: 'elementwise_mean')

    Return:
        Tensor with the reduced dice score.
    """
    n_classes = pred.shape[1]
    # index of the first class that is scored: 0 with bg, 1 without
    first_cls = 1 - int(bool(bg))
    scores = torch.zeros(n_classes - first_cls, device=pred.device, dtype=torch.float32)

    for cls in range(first_cls, n_classes):
        slot = cls - first_cls

        if not (target == cls).any():
            # no foreground class
            scores[slot] += no_fg_score
            continue

        tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=cls)

        denom = (2 * tp + fp + fn).to(torch.float)

        if torch.isclose(denom, torch.zeros_like(denom)).any():
            # nan result
            score_cls = nan_score
        else:
            score_cls = (2 * tp).to(torch.float) / denom

        scores[slot] += score_cls

    return reduce(scores, reduction=reduction)
|