restore functional metrics (#3943)
* restore functional metrics
* clean
* fix
parent
c77073f040
commit
b3d25262e9
pytorch_lightning/metrics/functional/__init__.py
@@ -0,0 +1,31 @@
from pytorch_lightning.metrics.functional.classification import (
    accuracy,
    auc,
    auroc,
    average_precision,
    confusion_matrix,
    dice_score,
    f1_score,
    fbeta_score,
    multiclass_precision_recall_curve,
    multiclass_roc,
    precision,
    precision_recall,
    precision_recall_curve,
    recall,
    roc,
    stat_scores,
    stat_scores_multiple_classes,
    to_categorical,
    to_onehot,
    iou,
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.functional.regression import (
    mae,
    mse,
    psnr,
    rmse,
    rmsle,
    ssim
)
pytorch_lightning/metrics/functional/classification.py
@@ -0,0 +1,964 @@
from collections.abc import Sequence
from functools import wraps
from typing import Optional, Tuple, Callable

import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.utilities import rank_zero_warn, FLOAT16_EPSILON

def to_onehot(
        tensor: torch.Tensor,
        num_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Converts a dense label tensor to one-hot format

    Args:
        tensor: dense label tensor, with shape [N, d1, d2, ...]
        num_classes: number of classes C

    Output:
        A sparse label tensor with shape [N, C, d1, d2, ...]

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> to_onehot(x)
        tensor([[0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]])

    """
    if num_classes is None:
        num_classes = int(tensor.max().detach().item() + 1)
    dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
    tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:],
                                dtype=dtype, device=device)
    index = tensor.long().unsqueeze(1).expand_as(tensor_onehot)
    return tensor_onehot.scatter_(1, index, 1.0)

def to_categorical(
        tensor: torch.Tensor,
        argmax_dim: int = 1
) -> torch.Tensor:
    """
    Converts a tensor of probabilities to a dense label tensor

    Args:
        tensor: probabilities to get the categorical label [N, d1, d2, ...]
        argmax_dim: dimension to apply the argmax over

    Return:
        A tensor with categorical labels [N, d2, ...]

    Example:

        >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
        >>> to_categorical(x)
        tensor([1, 0])

    """
    return torch.argmax(tensor, dim=argmax_dim)

def get_num_classes(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
) -> int:
    """
    Calculates the number of classes for a given prediction and target tensor.

    Args:
        pred: predicted values
        target: true labels
        num_classes: number of classes if known

    Return:
        An integer that represents the number of classes.
    """
    num_target_classes = int(target.max().detach().item() + 1)
    num_pred_classes = int(pred.max().detach().item() + 1)
    num_all_classes = max(num_target_classes, num_pred_classes)

    if num_classes is None:
        num_classes = num_all_classes
    elif num_classes != num_all_classes:
        rank_zero_warn(f'You have set {num_classes} number of classes which is different from'
                       f' predicted ({num_pred_classes}) and target ({num_target_classes}) number of classes')
    return num_classes
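
# A minimal usage sketch (not part of the commit; the tensors are
# illustrative) showing how get_num_classes infers the class count:
#
#   >>> pred = torch.tensor([0, 3, 1])    # highest predicted label: 3
#   >>> target = torch.tensor([0, 1, 2])  # highest true label: 2
#   >>> get_num_classes(pred, target)     # max label + 1 across both tensors
#   4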

def stat_scores(
        pred: torch.Tensor,
        target: torch.Tensor,
        class_index: int,
        argmax_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculates the number of true positive, false positive, true negative
    and false negative for a specific class

    Args:
        pred: prediction tensor
        target: target tensor
        class_index: class to calculate over
        argmax_dim: if pred is a tensor of probabilities, this indicates the
            axis the argmax transformation will be applied over

    Return:
        True Positive, False Positive, True Negative, False Negative, Support

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([0, 2, 3])
        >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1)
        >>> tp, fp, tn, fn, sup
        (tensor(0), tensor(1), tensor(2), tensor(0), tensor(0))

    """
    if pred.ndim == target.ndim + 1:
        pred = to_categorical(pred, argmax_dim=argmax_dim)

    tp = ((pred == class_index) * (target == class_index)).to(torch.long).sum()
    fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
    tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
    fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
    sup = (target == class_index).to(torch.long).sum()

    return tp, fp, tn, fn, sup

def stat_scores_multiple_classes(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        argmax_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calls the stat_scores function iteratively for all classes, thus
    calculating the number of true positive, false positive, true negative
    and false negative for each class

    Args:
        pred: prediction tensor
        target: target tensor
        num_classes: number of classes if known
        argmax_dim: if pred is a tensor of probabilities, this indicates the
            axis the argmax transformation will be applied over

    Return:
        True Positive, False Positive, True Negative, False Negative, Support

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([0, 2, 3])
        >>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y)
        >>> tps
        tensor([0., 0., 1., 1.])
        >>> fps
        tensor([0., 1., 0., 0.])
        >>> tns
        tensor([2., 2., 2., 2.])
        >>> fns
        tensor([1., 0., 0., 0.])
        >>> sups
        tensor([1., 0., 1., 1.])
    """
    num_classes = get_num_classes(pred=pred, target=target,
                                  num_classes=num_classes)

    if pred.ndim == target.ndim + 1:
        pred = to_categorical(pred, argmax_dim=argmax_dim)

    tps = torch.zeros((num_classes,), device=pred.device)
    fps = torch.zeros((num_classes,), device=pred.device)
    tns = torch.zeros((num_classes,), device=pred.device)
    fns = torch.zeros((num_classes,), device=pred.device)
    sups = torch.zeros((num_classes,), device=pred.device)
    for c in range(num_classes):
        tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c)

    return tps, fps, tns, fns, sups

def accuracy(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction='elementwise_mean',
) -> torch.Tensor:
    """
    Computes the accuracy classification score

    Args:
        pred: predicted labels
        target: ground truth labels
        num_classes: number of classes
        reduction: a method for reducing accuracies over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        A Tensor with the classification score.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> accuracy(x, y)
        tensor(0.7500)

    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(
        pred=pred, target=target, num_classes=num_classes)

    if not (target > 0).any() and num_classes is None:
        raise RuntimeError("cannot infer num_classes when target is all zero")

    if reduction in ('elementwise_mean', 'sum'):
        return reduce(sum(tps) / sum(sups), reduction=reduction)
    if reduction == 'none':
        return reduce(tps / sups, reduction=reduction)

def confusion_matrix(
        pred: torch.Tensor,
        target: torch.Tensor,
        normalize: bool = False,
) -> torch.Tensor:
    """
    Computes the confusion matrix C where each entry C_{i,j} is the number of observations
    in group i that were predicted in group j.

    Args:
        pred: estimated targets
        target: ground truth labels
        normalize: normalizes the confusion matrix

    Return:
        Tensor, confusion matrix C [num_classes, num_classes]

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([0, 2, 3])
        >>> confusion_matrix(x, y)
        tensor([[0., 1., 0., 0.],
                [0., 0., 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.]])
    """
    num_classes = get_num_classes(pred, target, None)

    # encode each (target, pred) pair as a single bin index, then count them
    unique_labels = target.view(-1) * num_classes + pred.view(-1)

    bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
    cm = bins.reshape(num_classes, num_classes).squeeze().float()

    if normalize:
        cm = cm / cm.sum(-1)

    return cm

def precision_recall(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Computes precision and recall scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision-recall values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with precision and recall

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> precision_recall(x, y)
        (tensor(0.7500), tensor(0.6250))

    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes)

    tps = tps.to(torch.float)
    fps = fps.to(torch.float)
    fns = fns.to(torch.float)

    precision = tps / (tps + fps)
    recall = tps / (tps + fns)

    # solution by justus, see https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/9
    precision[precision != precision] = 0
    recall[recall != recall] = 0

    precision = reduce(precision, reduction=reduction)
    recall = reduce(recall, reduction=reduction)
    return precision, recall

def precision(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes precision score.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing precision values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with precision.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> precision(x, y)
        tensor(0.7500)

    """
    return precision_recall(pred=pred, target=target,
                            num_classes=num_classes, reduction=reduction)[0]

def recall(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes recall score.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing recall values (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with recall.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> recall(x, y)
        tensor(0.6250)
    """
    return precision_recall(pred=pred, target=target,
                            num_classes=num_classes, reduction=reduction)[1]

def fbeta_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        beta: float,
        num_classes: Optional[int] = None,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Computes the F-beta score which is a weighted harmonic mean of precision and recall.
    It ranges between 0 and 1, where 1 is perfect and 0 is the worst value.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        beta: determines how recall is weighted when combining the score.
            beta < 1: more weight to precision.
            beta > 1: more weight to recall.
            beta = 0: only precision.
            beta -> inf: only recall.
        num_classes: number of classes
        reduction: method for reducing F-score (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements.

    Return:
        Tensor with the value of F-score. It is a value between 0-1.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> fbeta_score(x, y, 0.2)
        tensor(0.7407)
    """
    prec, rec = precision_recall(pred=pred, target=target,
                                 num_classes=num_classes,
                                 reduction='none')

    nom = (1 + beta ** 2) * prec * rec
    denom = ((beta ** 2) * prec + rec)
    fbeta = nom / denom

    # drop NaN after zero division
    fbeta[fbeta != fbeta] = 0

    return reduce(fbeta, reduction=reduction)

def f1_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        reduction='elementwise_mean',
) -> torch.Tensor:
    """
    Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall.
    It ranges between 0 and 1, where 1 is perfect and 0 is the worst value.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        num_classes: number of classes
        reduction: method for reducing F1-score (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements.

    Return:
        Tensor containing F1-score

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> f1_score(x, y)
        tensor(0.6667)
    """
    return fbeta_score(pred=pred, target=target, beta=1.,
                       num_classes=num_classes, reduction=reduction)

def _binary_clf_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
    """
    if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
        sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)

    # remove class dimension if necessary
    if pred.ndim > target.ndim:
        pred = pred[:, 0]
    desc_score_indices = torch.argsort(pred, descending=True)

    pred = pred[desc_score_indices]
    target = target[desc_score_indices]

    if sample_weight is not None:
        weight = sample_weight[desc_score_indices]
    else:
        weight = 1.

    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
    threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)

    target = (target == pos_label).to(torch.long)
    tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]

    if sample_weight is not None:
        # express fps as a cumsum to ensure fps is increasing even in
        # the presence of floating point errors
        fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
    else:
        fps = 1 + threshold_idxs - tps

    return fps, tps, pred[threshold_idxs]
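
# A sketch (illustrative, not part of the commit) of the values
# _binary_clf_curve returns: cumulative false/true positive counts at each
# distinct score, from the highest threshold down:
#
#   >>> pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
#   >>> target = torch.tensor([0, 0, 1, 1])
#   >>> fps, tps, thresholds = _binary_clf_curve(pred, target)
#   >>> fps, tps, thresholds
#   (tensor([0., 1., 1., 2.]), tensor([1., 1., 2., 2.]), tensor([0.8000, 0.4000, 0.3500, 0.1000]))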

def roc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the Receiver Operating Characteristic (ROC). It assumes the classifier is binary.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class

    Return:
        false-positive rate (fpr), true-positive rate (tpr), thresholds

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> fpr, tpr, thresholds = roc(x, y)
        >>> fpr
        tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
        >>> tpr
        tensor([0., 0., 0., 1., 1.])
        >>> thresholds
        tensor([4, 3, 2, 1, 0])

    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,
                                             pos_label=pos_label)

    # Add an extra threshold position
    # to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

    if fps[-1] <= 0:
        raise ValueError("No negative samples in targets, false positive value should be meaningless")

    fpr = fps / fps[-1]

    if tps[-1] <= 0:
        raise ValueError("No positive samples in targets, true positive value should be meaningless")

    tpr = tps / tps[-1]

    return fpr, tpr, thresholds

def multiclass_roc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
    """
    Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        num_classes: number of classes (default: None, computes automatically from data)

    Return:
        A tuple holding one ROC per class; each entry consists of the
        false-positive rate (fpr), true-positive rate (tpr) and thresholds

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> multiclass_roc(pred, target)   # doctest: +NORMALIZE_WHITESPACE
        ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
    """
    num_classes = get_num_classes(pred, target, num_classes)

    class_roc_vals = []
    for c in range(num_classes):
        pred_c = pred[:, c]

        class_roc_vals.append(roc(pred=pred_c, target=target,
                                  sample_weight=sample_weight, pos_label=c))

    return tuple(class_roc_vals)

def precision_recall_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes precision-recall pairs for different thresholds.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class

    Return:
        precision, recall, thresholds

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target)
        >>> precision
        tensor([0.3333, 0.0000, 0.0000, 1.0000])
        >>> recall
        tensor([1., 0., 0., 0.])
        >>> thresholds
        tensor([1, 2, 3])

    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,
                                             pos_label=pos_label)

    precision = tps / (tps + fps)
    recall = tps / tps[-1]

    # stop when full recall attained
    # and reverse the outputs so recall is decreasing
    last_ind = torch.where(tps == tps[-1])[0][0]
    sl = slice(0, last_ind.item() + 1)

    # need to call reversed explicitly, since including that to slice would
    # introduce negative strides that are not yet supported in pytorch
    precision = torch.cat([reversed(precision[sl]),
                           torch.ones(1, dtype=precision.dtype,
                                      device=precision.device)])

    recall = torch.cat([reversed(recall[sl]),
                        torch.zeros(1, dtype=recall.dtype,
                                    device=recall.device)])

    thresholds = torch.tensor(reversed(thresholds[sl]))

    return precision, recall, thresholds

def multiclass_precision_recall_curve(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
    """
    Computes precision-recall pairs for different thresholds given multiclass scores.

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        num_classes: number of classes

    Return:
        A tuple holding one (precision, recall, thresholds) entry per class

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> curves = multiclass_precision_recall_curve(pred, target)
        >>> curves[0]   # precision, recall, thresholds for class 0
        (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
        >>> curves[1]
        (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
        >>> curves[2]
        (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
        >>> curves[3]
        (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
    """
    num_classes = get_num_classes(pred, target, num_classes)

    class_pr_vals = []
    for c in range(num_classes):
        pred_c = pred[:, c]

        class_pr_vals.append(precision_recall_curve(
            pred=pred_c,
            target=target,
            sample_weight=sample_weight, pos_label=c))

    return tuple(class_pr_vals)

def auc(
        x: torch.Tensor,
        y: torch.Tensor,
        reorder: bool = True
) -> torch.Tensor:
    """
    Computes Area Under the Curve (AUC) using the trapezoidal rule

    Args:
        x: x-coordinates
        y: y-coordinates
        reorder: reorder coordinates, so they are increasing

    Return:
        Tensor containing AUC score (float)

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> auc(x, y)
        tensor(4.)
    """
    direction = 1.

    if reorder:
        # can't use lexsort here since it is not implemented for torch
        order = torch.argsort(x)
        x, y = x[order], y[order]
    else:
        dx = x[1:] - x[:-1]
        if (dx < 0).any():
            if (dx <= 0).all():
                # x is monotonically decreasing, so integrate in reverse
                direction = -1.
            else:
                raise ValueError("Reordering is not turned on, and "
                                 "the x array is not increasing: %s" % x)

    return direction * torch.trapz(y, x)

def auc_decorator(reorder: bool = True) -> Callable:
    def wrapper(func_to_decorate: Callable) -> Callable:
        @wraps(func_to_decorate)
        def new_func(*args, **kwargs) -> torch.Tensor:
            x, y = func_to_decorate(*args, **kwargs)[:2]

            return auc(x, y, reorder=reorder)

        return new_func

    return wrapper

def multiclass_auc_decorator(reorder: bool = True) -> Callable:
    def wrapper(func_to_decorate: Callable) -> Callable:
        def new_func(*args, **kwargs) -> torch.Tensor:
            results = []
            for class_result in func_to_decorate(*args, **kwargs):
                x, y = class_result[:2]
                results.append(auc(x, y, reorder=reorder))

            # auc returns zero-dimensional tensors, which torch.cat cannot
            # concatenate; stack them into a 1-d tensor instead
            return torch.stack(results)

        return new_func

    return wrapper
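
# A sketch of the intended use of multiclass_auc_decorator; this wrapper is
# not applied anywhere in this commit, so the wiring below is an assumption:
#
#   @multiclass_auc_decorator(reorder=True)
#   def _multiclass_auroc(pred, target, sample_weight, num_classes):
#       return multiclass_roc(pred, target, sample_weight, num_classes)
#
# _multiclass_auroc would then return a 1-d tensor holding one ROC-AUC
# value per class.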

def auroc(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> torch.Tensor:
    """
    Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class

    Return:
        Tensor containing ROCAUC score

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> auroc(x, y)
        tensor(0.3333)
    """

    @auc_decorator(reorder=True)
    def _auroc(pred, target, sample_weight, pos_label):
        return roc(pred, target, sample_weight, pos_label)

    return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)

def average_precision(
        pred: torch.Tensor,
        target: torch.Tensor,
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1.,
) -> torch.Tensor:
    """
    Compute average precision from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class

    Return:
        Tensor containing average precision score

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> average_precision(x, y)
        tensor(0.3333)
    """
    precision, recall, _ = precision_recall_curve(pred=pred, target=target,
                                                  sample_weight=sample_weight,
                                                  pos_label=pos_label)
    # Return the step function integral
    # The following works because the last entry of precision is
    # guaranteed to be 1, as returned by precision_recall_curve
    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])

def dice_score(
        pred: torch.Tensor,
        target: torch.Tensor,
        bg: bool = False,
        nan_score: float = 0.0,
        no_fg_score: float = 0.0,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Compute dice score from prediction scores

    Args:
        pred: estimated probabilities
        target: ground-truth labels
        bg: whether to also compute dice for the background
        nan_score: score to return, if a NaN occurs during computation
        no_fg_score: score to return, if no foreground pixel was found in target
        reduction: a method for reducing accuracies over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor containing dice score

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)

    """
    num_classes = pred.shape[1]
    # offset: skip the background class (index 0) unless bg=True
    bg = (1 - int(bool(bg)))
    scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32)
    for i in range(bg, num_classes):
        if not (target == i).any():
            # no foreground class
            scores[i - bg] += no_fg_score
            continue

        tp, fp, tn, fn, sup = stat_scores(pred=pred, target=target, class_index=i)
        denom = (2 * tp + fp + fn).to(torch.float)
        # guard against NaN from zero division
        score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else nan_score

        scores[i - bg] += score_cls
    return reduce(scores, reduction=reduction)

def iou(
        pred: torch.Tensor,
        target: torch.Tensor,
        num_classes: Optional[int] = None,
        remove_bg: bool = False,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Intersection over union, or Jaccard index calculation.

    Args:
        pred: Tensor containing predictions
        target: Tensor containing targets
        num_classes: Optionally specify the number of classes
        remove_bg: Flag to state whether a background class has been included
            within input parameters. If true, will remove background class. If
            false, return IoU over all classes.
            Assumes that background is the '0' class in the input tensor
        reduction: a method for reducing IoU over labels (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        IoU score: Tensor containing a single value if reduction is
        'elementwise_mean', or one value per class if reduction is 'none'

    Example:

        >>> target = torch.randint(0, 1, (10, 25, 25))
        >>> pred = torch.tensor(target)
        >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15]
        >>> iou(pred, target)
        tensor(0.4914)

    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
    if remove_bg:
        tps = tps[1:]
        fps = fps[1:]
        fns = fns[1:]
    denom = fps + fns + tps
    # avoid division by zero for classes absent from both pred and target
    denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom)
    iou = tps / denom
    return reduce(iou, reduction=reduction)
pytorch_lightning/metrics/functional/nlp.py
@@ -0,0 +1,92 @@
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import Sequence, List

import torch

def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
    """Counts how often each n-gram (of length 1 up to n_gram) appears in a given text

    Args:
        ngram_input_list: A list of translated text or reference texts
        n_gram: gram value ranged 1 to 4

    Return:
        ngram_counter: a collections.Counter object of ngram
    """

    ngram_counter = Counter()

    for i in range(1, n_gram + 1):
        for j in range(len(ngram_input_list) - i + 1):
            ngram_key = tuple(ngram_input_list[j:i + j])
            ngram_counter[ngram_key] += 1

    return ngram_counter
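
# An illustrative sketch (not part of the commit) of what _count_ngram
# produces for n_gram=2: every unigram and bigram of the token list, keyed
# as a tuple:
#
#   >>> _count_ngram(['the', 'cat', 'sat'], 2)
#   Counter({('the',): 1, ('cat',): 1, ('sat',): 1, ('the', 'cat'): 1, ('cat', 'sat'): 1})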

def bleu_score(
    translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False
) -> torch.Tensor:
    """Calculate BLEU score of machine translated text with one or more references.

    Args:
        translate_corpus: An iterable of machine translated corpus
        reference_corpus: An iterable of iterables of reference corpus
        n_gram: Gram value ranged from 1 to 4 (Default 4)
        smooth: Whether to apply smoothing (Lin et al., 2004)

    Return:
        A Tensor with BLEU Score

    Example:

        >>> translate_corpus = ['the cat is on the mat'.split()]
        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
        >>> bleu_score(translate_corpus, reference_corpus)
        tensor(0.7598)
    """

    assert len(translate_corpus) == len(reference_corpus)
    numerator = torch.zeros(n_gram)
    denominator = torch.zeros(n_gram)
    precision_scores = torch.zeros(n_gram)
    c = 0.0
    r = 0.0
    for (translation, references) in zip(translate_corpus, reference_corpus):
        c += len(translation)
        # the reference closest in length to the translation determines r
        ref_len_list = [len(ref) for ref in references]
        ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
        r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
        translation_counter = _count_ngram(translation, n_gram)
        reference_counter = Counter()
        for ref in references:
            reference_counter |= _count_ngram(ref, n_gram)

        # clip each n-gram count to its maximum count across the references
        ngram_counter_clip = translation_counter & reference_counter
        for counter_clip in ngram_counter_clip:
            numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]

        for counter in translation_counter:
            denominator[len(counter) - 1] += translation_counter[counter]

    trans_len = torch.tensor(c)
    ref_len = torch.tensor(r)
    if min(numerator) == 0.0:
        return torch.tensor(0.0)

    if smooth:
        precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
    else:
        precision_scores = numerator / denominator
    log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
    geometric_mean = torch.exp(torch.sum(log_precision_scores))
    brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
    bleu = brevity_penalty * geometric_mean

    return bleu
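
# A worked sketch of the brevity penalty used above (numbers are
# illustrative): for a 5-token translation (c=5) whose closest reference has
# 7 tokens (r=7), BP = exp(1 - r/c) = exp(1 - 7/5) ~= 0.6703, which scales
# down the geometric mean of the n-gram precisions; when c > r, BP is 1.0.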
pytorch_lightning/metrics/functional/reduction.py
@@ -0,0 +1,24 @@
import torch


def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Reduces a given tensor by a given reduction method

    Args:
        to_reduce: the tensor which shall be reduced
        reduction: a string specifying the reduction method ('elementwise_mean', 'none', 'sum')

    Return:
        reduced Tensor

    Raise:
        ValueError if an invalid reduction parameter was given
    """
    if reduction == 'elementwise_mean':
        return torch.mean(to_reduce)
    if reduction == 'none':
        return to_reduce
    if reduction == 'sum':
        return torch.sum(to_reduce)
    raise ValueError('Reduction parameter unknown.')
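
# Usage sketch (illustrative, not part of the commit): the three supported
# reduction modes side by side:
#
#   >>> t = torch.tensor([1., 2., 3.])
#   >>> reduce(t, 'elementwise_mean'), reduce(t, 'sum')
#   (tensor(2.), tensor(6.))
#   >>> reduce(t, 'none')
#   tensor([1., 2., 3.])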
pytorch_lightning/metrics/functional/regression.py
@@ -0,0 +1,297 @@
from typing import Sequence

import torch
from torch.nn import functional as F

from pytorch_lightning.metrics.functional.reduction import reduce

def mse(
        pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Computes mean squared error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: method for reducing mse (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with MSE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mse(x, y)
        tensor(0.2500)

    """
    mse = F.mse_loss(pred, target, reduction='none')
    mse = reduce(mse, reduction=reduction)
    return mse

def rmse(
        pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Computes root mean squared error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: method for reducing rmse (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with RMSE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> rmse(x, y)
        tensor(0.5000)

    """
    rmse = torch.sqrt(mse(pred, target, reduction=reduction))
    return rmse

def mae(
        pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Computes mean absolute error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: method for reducing mae (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with MAE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> mae(x, y)
        tensor(0.2500)

    """
    mae = F.l1_loss(pred, target, reduction='none')
    mae = reduce(mae, reduction=reduction)
    return mae

def rmsle(
        pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Computes root mean squared log error

    Args:
        pred: estimated labels
        target: ground truth labels
        reduction: method for reducing rmsle (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with RMSLE

    Example:

        >>> x = torch.tensor([0., 1, 2, 3])
        >>> y = torch.tensor([0., 1, 2, 2])
        >>> rmsle(x, y)
        tensor(0.0207)

    """
    # note: this returns the mean of the squared log errors; no square root
    # is applied here, which matches the example value above
    rmsle = mse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction)
    return rmsle

def psnr(
        pred: torch.Tensor,
        target: torch.Tensor,
        data_range: float = None,
        base: float = 10.0,
        reduction: str = 'elementwise_mean'
) -> torch.Tensor:
    """
    Computes the peak signal-to-noise ratio

    Args:
        pred: estimated signal
        target: ground truth signal
        data_range: the range of the data. If None, it is determined from the data (max - min)
        base: a base of a logarithm to use (default: 10)
        reduction: method for reducing psnr (default: takes the mean)
            Available reduction methods:

            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

    Return:
        Tensor with PSNR score

    Example:

        >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
        >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
        >>> psnr(pred, target)
        tensor(2.5527)

    """

    if data_range is None:
        data_range = max(target.max() - target.min(), pred.max() - pred.min())
    else:
        data_range = torch.tensor(float(data_range))

    mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction)
    # psnr = 10 * log_base(data_range ** 2 / mse): compute in natural log,
    # then change the base
    psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
    psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
    return psnr

def _gaussian_kernel(channel, kernel_size, sigma, device):
    def gaussian(kernel_size, sigma, device):
        gauss = torch.arange(
            start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device
        )
        gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
        return (gauss / gauss.sum()).unsqueeze(dim=0)  # (1, kernel_size)

    gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device)
    gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device)
    kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)  # (kernel_size, 1) * (1, kernel_size)

    return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
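
# A shape sketch (illustrative): for a 3-channel image and an (11, 11)
# window, the two 1-d Gaussians combine into one depthwise conv2d kernel:
#
#   >>> _gaussian_kernel(channel=3, kernel_size=(11, 11), sigma=(1.5, 1.5),
#   ...                  device=torch.device('cpu')).shape
#   torch.Size([3, 1, 11, 11])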

def ssim(
        pred: torch.Tensor,
        target: torch.Tensor,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: str = "elementwise_mean",
        data_range: float = None,
        k1: float = 0.01,
        k2: float = 0.03
) -> torch.Tensor:
    """
    Computes Structural Similarity Index Measure

    Args:
        pred: Estimated image
        target: Ground truth image
        kernel_size: Size of the gaussian kernel. Default: (11, 11)
        sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
        reduction: A method for reducing ssim over all elements in the ``pred`` tensor. Default: ``elementwise_mean``

            Available reduction methods:
            - elementwise_mean: takes the mean
            - none: pass array
            - sum: add elements

        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03

    Returns:
        A Tensor with SSIM

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 1.25
        >>> ssim(pred, target)
        tensor(0.9520)
    """

    if pred.dtype != target.dtype:
        raise TypeError(
            "Expected `pred` and `target` to have the same data type."
            f" Got pred: {pred.dtype} and target: {target.dtype}."
        )

    if pred.shape != target.shape:
        raise ValueError(
            "Expected `pred` and `target` to have the same shape."
            f" Got pred: {pred.shape} and target: {target.shape}."
        )

    if len(pred.shape) != 4 or len(target.shape) != 4:
        raise ValueError(
            "Expected `pred` and `target` to have BxCxHxW shape."
            f" Got pred: {pred.shape} and target: {target.shape}."
        )

    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
        )

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        data_range = max(pred.max() - pred.min(), target.max() - target.min())

    C1 = pow(k1 * data_range, 2)
    C2 = pow(k2 * data_range, 2)
    device = pred.device

    channel = pred.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
    mu_pred = F.conv2d(pred, kernel, groups=channel)
    mu_target = F.conv2d(target, kernel, groups=channel)

    mu_pred_sq = mu_pred.pow(2)
    mu_target_sq = mu_target.pow(2)
    mu_pred_target = mu_pred * mu_target

    sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq
    sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq
    sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target

    UPPER = 2 * sigma_pred_target + C2
    LOWER = sigma_pred_sq + sigma_target_sq + C2

    # SSIM(x, y) = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2))
    #            / ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2))
    ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)

    return reduce(ssim_idx, reduction)
@@ -0,0 +1,341 @@
from functools import partial

import pytest
import torch
from sklearn.metrics import (
    accuracy_score as sk_accuracy,
    precision_score as sk_precision,
    recall_score as sk_recall,
    f1_score as sk_f1_score,
    fbeta_score as sk_fbeta_score,
    confusion_matrix as sk_confusion_matrix,
)

from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional.classification import (
    to_onehot,
    to_categorical,
    get_num_classes,
    stat_scores,
    stat_scores_multiple_classes,
    accuracy,
    confusion_matrix,
    precision,
    recall,
    fbeta_score,
    f1_score,
    _binary_clf_curve,
    dice_score,
    average_precision,
    auroc,
    precision_recall_curve,
    roc,
    auc,
    iou,
)

@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
    pytest.param(sk_accuracy, accuracy, id='accuracy'),
    pytest.param(partial(sk_precision, average='macro'), precision, id='precision'),
    pytest.param(partial(sk_recall, average='macro'), recall, id='recall'),
    pytest.param(partial(sk_f1_score, average='macro'), f1_score, id='f1_score'),
    pytest.param(partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2), id='fbeta_score'),
    pytest.param(sk_confusion_matrix, confusion_matrix, id='confusion_matrix')
])
def test_against_sklearn(sklearn_metric, torch_metric):
    """Compare PL metrics to sklearn version."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # iterate over different label counts in predictions and target
    for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
        pred = torch.randint(n_cls_pred, (300,), device=device)
        target = torch.randint(n_cls_target, (300,), device=device)

        sk_score = sklearn_metric(target.cpu().detach().numpy(),
                                  pred.cpu().detach().numpy())
        sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
        pl_score = torch_metric(pred, target)
        assert torch.allclose(sk_score, pl_score)

def test_onehot():
    test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
    expected = torch.stack([
        torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
        torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
    ])

    assert test_tensor.shape == (2, 5)
    assert expected.shape == (2, 10, 5)

    onehot_classes = to_onehot(test_tensor, num_classes=10)
    onehot_no_classes = to_onehot(test_tensor)

    assert torch.allclose(onehot_classes, onehot_no_classes)

    assert onehot_classes.shape == expected.shape
    assert onehot_no_classes.shape == expected.shape

    assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
    assert torch.allclose(expected.to(onehot_classes), onehot_classes)

def test_to_categorical():
    test_tensor = torch.stack([
        torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
        torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
    ]).to(torch.float)

    expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
    assert expected.shape == (2, 5)
    assert test_tensor.shape == (2, 10, 5)

    result = to_categorical(test_tensor)

    assert result.shape == expected.shape
    assert torch.allclose(result, expected.to(result.dtype))

@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
    pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
    pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
    pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
])
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
    assert get_num_classes(pred, target, num_classes) == expected_num_classes

@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
                          'expected_tn', 'expected_fn', 'expected_support'], [
    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2),
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2)
])
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
    tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)

    assert tp.item() == expected_tp
    assert fp.item() == expected_fp
    assert tn.item() == expected_tn
    assert fn.item() == expected_fn
    assert sup.item() == expected_support

@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
                          'expected_tn', 'expected_fn', 'expected_support'], [
    pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
    pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
                 [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2])
])
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
    tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target)

    assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
    assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
    assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
    assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
    assert torch.allclose(torch.tensor(expected_support).to(sup), sup)

def test_multilabel_accuracy():
    # Dense label indicator matrix format
    y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
    y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])

    assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.]))
    assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
    assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
    assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
    assert torch.allclose(accuracy(y1, torch.logical_not(y1), reduction='none'), torch.tensor([0., 0.]))

    with pytest.raises(RuntimeError):
        accuracy(y2, torch.zeros_like(y2), reduction='none')

def test_accuracy():
    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])
    acc = accuracy(pred, target)

    assert acc.item() == 0.75

    pred = torch.tensor([0, 1, 2, 2])
    target = torch.tensor([0, 1, 1, 3])
    acc = accuracy(pred, target)

    assert acc.item() == 0.50

def test_confusion_matrix():
    target = (torch.arange(120) % 3).view(-1, 1)
    pred = target.clone()
    cm = confusion_matrix(pred, target, normalize=True)

    assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))

    pred = torch.zeros_like(pred)
    cm = confusion_matrix(pred, target, normalize=True)
    assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))

@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
    pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
    pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
])
def test_precision_recall(pred, target, expected_prec, expected_rec):
    prec = precision(pred, target, reduction='none')
    rec = recall(pred, target, reduction='none')

    assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
    assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)

@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]),
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]),
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
])
def test_fbeta_score(pred, target, beta, exp_score):
    score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))

    score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))

@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
    pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]),
    pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
    pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
])
def test_f1_score(pred, target, exp_score):
    score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))

    score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), reduction='none')
    assert torch.allclose(score, torch.tensor(exp_score))

@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [
    pytest.param(1, 1., 42),
    pytest.param(None, 1., 42),
])
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
    # TODO: move back the pred and target to test func arguments
    # if you fix the array inside the function, you'd also have to fix the
    # expected shape, because when the array changes, the shape changes too
    seed_everything(0)
    pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
    target = torch.tensor([0, 1] * 50, dtype=torch.int)
    if sample_weight is not None:
        sample_weight = torch.ones_like(pred) * sample_weight

    fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)

    assert isinstance(tps, torch.Tensor)
    assert isinstance(fps, torch.Tensor)
    assert isinstance(thresh, torch.Tensor)
    assert tps.shape == (exp_shape,)
    assert fps.shape == (exp_shape,)
    assert thresh.shape == (exp_shape,)

@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
    pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
])
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
    p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
    assert p.size() == r.size()
    assert p.size(0) == t.size(0) + 1

    assert torch.allclose(p, torch.tensor(expected_p).to(p))
    assert torch.allclose(r, torch.tensor(expected_r).to(r))
    assert torch.allclose(t, torch.tensor(expected_t).to(t))
@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
|
||||
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
|
||||
pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
|
||||
pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
|
||||
pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
|
||||
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
|
||||
])
|
||||
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
|
||||
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
|
||||
|
||||
assert fpr.shape == tpr.shape
|
||||
assert fpr.size(0) == thresh.size(0)
|
||||
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
|
||||
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
|
||||
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
|
||||
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
|
||||
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
|
||||
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
|
||||
])
|
||||
def test_auroc(pred, target, expected):
|
||||
score = auroc(torch.tensor(pred), torch.tensor(target)).item()
|
||||
assert score == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['x', 'y', 'expected'], [
|
||||
pytest.param([0, 1], [0, 1], 0.5),
|
||||
pytest.param([1, 0], [0, 1], 0.5),
|
||||
pytest.param([1, 0, 0], [0, 1, 1], 0.5),
|
||||
pytest.param([0, 1], [1, 1], 1),
|
||||
pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
|
||||
])
|
||||
def test_auc(x, y, expected):
|
||||
# Test Area Under Curve (AUC) computation
|
||||
assert auc(torch.tensor(x), torch.tensor(y)) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
|
||||
# Check the average_precision_score of a constant predictor is
|
||||
# the TPR
|
||||
# Generate a dataset with 25% of positives
|
||||
# And a constant score
|
||||
# The precision is then the fraction of positive whatever the recall
|
||||
# is, as there is only one threshold:
|
||||
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
|
||||
# With threshold 0.8 : 1 TP and 2 TN and one FN
|
||||
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
|
||||
])
|
||||
def test_average_precision(scores, target, expected_score):
|
||||
assert average_precision(scores, target) == expected_score
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
|
||||
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),
|
||||
pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
|
||||
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
|
||||
])
|
||||
def test_dice_score(pred, target, expected):
|
||||
score = dice_score(torch.tensor(pred), torch.tensor(target))
|
||||
assert score == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [
|
||||
pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])),
|
||||
pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])),
|
||||
pytest.param(False, 'none', True, torch.Tensor([1, 1])),
|
||||
pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
|
||||
pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])),
|
||||
pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])),
|
||||
])
|
||||
def test_iou(half_ones, reduction, remove_bg, expected):
|
||||
pred = (torch.arange(120) % 3).view(-1, 1)
|
||||
target = (torch.arange(120) % 3).view(-1, 1)
|
||||
if half_ones:
|
||||
pred[:60] = 1
|
||||
iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction)
|
||||
assert torch.allclose(iou_val, expected, atol=1e-9)
|
||||
|
||||
|
||||
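A minimal usage sketch of the functional classification API exercised above, assuming only the functions these tests already import (the input values are illustrative):

    pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
    target = torch.tensor([0, 0, 1, 1])
    prec, rec, thresh = precision_recall_curve(pred, target)  # prec/rec are one element longer than thresh
    fpr, tpr, roc_thresh = roc(pred, target)                  # fpr, tpr and thresholds share one length
    score = auroc(pred, target)                               # zero-dim tensor in [0, 1]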
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@ -0,0 +1,66 @@
import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu

from pytorch_lightning.metrics.functional.nlp import bleu_score

# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
HYPOTHESIS1 = tuple(
    "It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE2 = tuple(
    "It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())


# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()

REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()

LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]

# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2


@pytest.mark.parametrize(
    ["weights", "n_gram", "smooth_func", "smooth"],
    [
        pytest.param([1], 1, None, False),
        pytest.param([0.5, 0.5], 2, smooth_func, True),
        pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
        pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
    ],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
    nltk_output = sentence_bleu(
        [REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
    )
    pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))

    nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
    pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
    assert torch.allclose(pl_output, torch.tensor(nltk_output))


def test_bleu_empty():
    hyp = [[]]
    ref = [[[]]]
    assert bleu_score(hyp, ref) == torch.tensor(0.0)


def test_no_4_gram():
    hyps = [["My", "full", "pytorch-lightning"]]
    refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
    assert bleu_score(hyps, refs) == torch.tensor(0.0)
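For reference, a minimal sketch of the nesting bleu_score expects, matching the corpus-level calls above (made-up tokens): each hypothesis is a list of tokens, paired with its own list of tokenized references.

    hyps = [["the", "cat", "sat"]]                   # corpus: one tokenized hypothesis
    refs = [[["the", "cat", "sat"], ["a", "cat"]]]   # that hypothesis' reference translations
    score = bleu_score(hyps, refs, n_gram=2, smooth=True)  # zero-dim float tensor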
@ -0,0 +1,15 @@
import pytest
import torch

from pytorch_lightning.metrics.functional.reduction import reduce


def test_reduce():
    start_tensor = torch.rand(50, 40, 30)

    assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
    assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)

    with pytest.raises(ValueError):
        reduce(start_tensor, 'error_reduction')
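For intuition, a minimal re-implementation consistent with the assertions above (a sketch using the imports above, not the library code):

    def reduce_sketch(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
        if reduction == 'elementwise_mean':
            return torch.mean(to_reduce)
        if reduction == 'sum':
            return torch.sum(to_reduce)
        if reduction == 'none':
            return to_reduce
        raise ValueError('unknown reduction')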
@ -0,0 +1,144 @@
import numpy as np
import pytest
import torch
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
from skimage.metrics import structural_similarity as ski_ssim

from pytorch_lightning.metrics.functional import (
    mae,
    mse,
    psnr,
    rmse,
    rmsle,
    ssim
)


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
    pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
    pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0),
])
def test_mse(pred, target, expected):
    score = mse(torch.tensor(pred), torch.tensor(target))
    assert score.item() == expected


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
    pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
    pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.5),
    pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321),
])
def test_rmse(pred, target, expected):
    score = rmse(torch.tensor(pred), torch.tensor(target))
    assert torch.allclose(score, torch.tensor(expected), atol=1e-3)


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
    pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
    pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
    pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5),
])
def test_mae(pred, target, expected):
    score = mae(torch.tensor(pred), torch.tensor(target))
    assert score.item() == expected


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
    pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
    pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.0207),
    pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841),
])
def test_rmsle(pred, target, expected):
    score = rmsle(torch.tensor(pred), torch.tensor(target))
    assert torch.allclose(score, torch.tensor(expected), atol=1e-3)


@pytest.mark.parametrize(['pred', 'target'], [
    pytest.param([0., 1., 2., 3.], [0., 1., 2., 3.]),
    pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
    pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
])
def test_psnr_with_skimage(pred, target):
    score = psnr(pred=torch.tensor(pred),
                 target=torch.tensor(target))
    sk_score = ski_psnr(np.array(pred), np.array(target), data_range=3)
    assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float), atol=1e-3)


@pytest.mark.parametrize(['pred', 'target'], [
    pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
    pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
])
def test_psnr_base_e_wider_range(pred, target):
    score = psnr(pred=torch.tensor(pred),
                 target=torch.tensor(target),
                 data_range=4,
                 base=2.718281828459045)
    # with base=e, the score equals the base-10 PSNR scaled by ln(10)
    sk_score = ski_psnr(np.array(pred), np.array(target), data_range=4) * np.log(10)
    assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float32), atol=1e-3)
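As a quick sanity check of the first test_psnr_with_skimage case above under the usual definition PSNR = 10 * log10(data_range ** 2 / MSE), with data_range=3 and MSE=0.25:

    import math
    assert abs(10 * math.log10(3 ** 2 / 0.25) - 15.563) < 1e-3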
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
    pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio')
])
def test_psnr_against_sklearn(sklearn_metric, torch_metric):
    """Compare PL metrics to the scikit-image implementation."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
        pred = torch.randint(n_cls_pred, (500,), device=device, dtype=torch.float)
        target = torch.randint(n_cls_target, (500,), device=device, dtype=torch.float)

        sk_score = sklearn_metric(target.cpu().detach().numpy(),
                                  pred.cpu().detach().numpy(),
                                  data_range=n_cls_target)
        sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
        pl_score = torch_metric(pred, target, data_range=n_cls_target)
        assert torch.allclose(sk_score, pl_score)


@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [
    pytest.param(16, 1, 0.125, False),
    pytest.param(32, 1, 0.25, False),
    pytest.param(48, 3, 0.5, True),
    pytest.param(64, 4, 0.75, True),
    pytest.param(128, 5, 1, True)
])
def test_ssim(size, channel, plus, multichannel):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(1, channel, size, size, device=device)
    target = pred + plus
    ssim_idx = ssim(pred, target)
    np_pred = np.random.rand(size, size, channel)
    if multichannel is False:
        np_pred = np_pred[:, :, 0]
    np_target = np.add(np_pred, plus)
    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)

    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))


@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
    pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # shape
    pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # len(shape)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]),  # len(kernel), len(sigma)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]),  # len(kernel), len(sigma)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]),  # len(kernel), len(sigma)
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]),  # invalid kernel input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]),  # invalid kernel input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]),  # invalid kernel input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]),  # invalid sigma input
    pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]),  # invalid sigma input
])
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
    pred_t = torch.rand(pred)
    target_t = torch.rand(target, dtype=torch.float64)
    with pytest.raises(TypeError):
        ssim(pred_t, target_t)

    pred = torch.rand(pred)
    target = torch.rand(target)
    with pytest.raises(ValueError):
        ssim(pred, target, kernel, sigma)
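A minimal ssim usage sketch consistent with the tests above (NCHW float tensors of matching dtype, spatial size at least the default 11x11 kernel; identical inputs score 1.0):

    pred = torch.rand(1, 3, 32, 32)
    assert torch.allclose(ssim(pred, pred.clone()), torch.tensor(1.0))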