Fix parameters and docs in metrics (#2473)
* Fix parameters and docs in metrics
* doc improvements
* whitespace
* doc indentation
* Apply suggestions from code review

  Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* zero
* drop defaults

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
This commit is contained in:
parent 1dc724239a
commit d3f5717e81
@@ -10,11 +10,11 @@ Metrics
 This is a general package for PyTorch Metrics. These can also be used with regular non-lightning PyTorch code.
 Metrics are used to monitor model performance.
 
-In this package we provide two major pieces of functionality.
+In this package, we provide two major pieces of functionality.
 
 1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic.
-2. A collection of ready to use pupular metrics. There are two types of metrics: Class metrics and Functional metrics.
-3. A interface to call `sklearns metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
+2. A collection of ready to use popular metrics. There are two types of metrics: Class metrics and Functional metrics.
+3. An interface to call `sklearns metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
 
 Example::
 
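For context, a minimal sketch of the functional-metric usage this page describes, assuming the 0.8-era import path::

    import torch
    from pytorch_lightning.metrics.functional import accuracy  # functional metric

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])
    acc = accuracy(pred, target)  # scalar tensor; the class metrics wrap the same computations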
@@ -28,14 +28,13 @@ Example::
 
 .. warning::
     The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR!
-    to a few metrics. Please feel free to create an issue/PR if you have a proposed
-    metric or have found a bug.
+    to a few metrics. Please feel free to create an issue/PR if you have a proposed metric or have found a bug.
 
 ----------------
 
 Implement a metric
 ------------------
-You can implement metrics as either a PyTorch metric or a Numpy metric (It is recommend to use PyTorch metrics when possible,
+You can implement metrics as either a PyTorch metric or a Numpy metric (It is recommended to use PyTorch metrics when possible,
 since Numpy metrics slow down training).
 
 Use :class:`TensorMetric` to implement native PyTorch metrics. This class
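A sketch of what such a :class:`TensorMetric` subclass can look like (the RMSE example mirrors the one used in these docs; the exact base-class import path is assumed)::

    import torch
    from pytorch_lightning.metrics.metric import TensorMetric

    class RMSE(TensorMetric):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            # computed on tensors, so the metric stays device agnostic
            return torch.sqrt(torch.mean(torch.pow(x - y, 2.0)))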
@@ -403,8 +402,8 @@ To use the sklearn backend of metrics simply import as
     val = metric(pred, target)
 
 Each converted sklearn metric comes has the same interface as its
-originally counterpart (e.g. accuracy takes the additional `normalize` keyword).
-Like the native Lightning metrics these converted sklearn metrics also come
+original counterpart (e.g. accuracy takes the additional `normalize` keyword).
+Like the native Lightning metrics, these converted sklearn metrics also come
 with built-in distributed (ddp) support.
 
 SklearnMetric (sk)
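A hedged sketch of the sklearn-backed interface described in this hunk (the class name and constructor keyword are assumed from the `normalize` example above)::

    import torch
    from pytorch_lightning.metrics.sklearns import Accuracy

    metric = Accuracy(normalize=True)  # extra keyword forwarded to sklearn
    val = metric(torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0]))  # fraction correct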
@@ -11,15 +11,14 @@ from pytorch_lightning.utilities import rank_zero_warn
 
 def to_onehot(
     tensor: torch.Tensor,
-    n_classes: Optional[int] = None,
+    num_classes: Optional[int] = None,
 ) -> torch.Tensor:
     """
     Converts a dense label tensor to one-hot format
 
     Args:
         tensor: dense label tensor, with shape [N, d1, d2, ...]
-
-        n_classes: number of classes C
+        num_classes: number of classes C
 
     Output:
         A sparse label tensor with shape [N, C, d1, d2, ...]
@@ -33,22 +32,25 @@ def to_onehot(
                 [0, 0, 0, 1]])
 
     """
-    if n_classes is None:
-        n_classes = int(tensor.max().detach().item() + 1)
+    if num_classes is None:
+        num_classes = int(tensor.max().detach().item() + 1)
     dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
-    tensor_onehot = torch.zeros(shape[0], n_classes, *shape[1:],
+    tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:],
                                 dtype=dtype, device=device)
     index = tensor.long().unsqueeze(1).expand_as(tensor_onehot)
     return tensor_onehot.scatter_(1, index, 1.0)
 
 
-def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
+def to_categorical(
+        tensor: torch.Tensor,
+        argmax_dim: int = 1
+) -> torch.Tensor:
     """
     Converts a tensor of probabilities to a dense label tensor
 
     Args:
         tensor: probabilities to get the categorical label [N, d1, d2, ...]
-        argmax_dim: dimension to apply (default: 1)
+        argmax_dim: dimension to apply
 
     Return:
         A tensor with categorical labels [N, d2, ...]
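The two helpers are inverses of each other; a quick round-trip check against the renamed `num_classes` argument (import path assumed)::

    import torch
    from pytorch_lightning.metrics.functional import to_onehot, to_categorical

    labels = torch.tensor([0, 2, 1])
    onehot = to_onehot(labels, num_classes=3)           # shape [3, 3], one-hot rows
    assert torch.equal(to_categorical(onehot), labels)  # argmax over dim 1 recovers labels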
@@ -66,15 +68,15 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
 def get_num_classes(
     pred: torch.Tensor,
     target: torch.Tensor,
-    num_classes: Optional[int],
+    num_classes: Optional[int] = None,
 ) -> int:
     """
-    Returns the number of classes for a given prediction and target tensor.
+    Calculates the number of classes for a given prediction and target tensor.
 
     Args:
         pred: predicted values
         target: true labels
-        num_classes: number of classes if known (default: None)
+        num_classes: number of classes if known
 
     Return:
         An integer that represents the number of classes.
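With the new default, the class count can simply be inferred when not given (inference from the highest label index is what the surrounding helpers imply)::

    import torch
    from pytorch_lightning.metrics.functional import get_num_classes

    n = get_num_classes(pred=torch.tensor([0, 2]), target=torch.tensor([1, 2]))
    # n == 3: no explicit num_classes needed, max label + 1 is assumed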
@@ -108,7 +110,7 @@ def stat_scores(
             axis the argmax transformation will be applied over
 
     Return:
-        True Positive, False Positive, True Negative, False Negative
+        True Positive, False Positive, True Negative, False Negative, Support
 
     Example:
 
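The documented return value now matches the five tensors the function yields; a small sanity check (import path assumed)::

    import torch
    from pytorch_lightning.metrics.functional import stat_scores

    pred = torch.tensor([0, 1, 1])
    target = torch.tensor([0, 0, 1])
    tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=1)
    # for class 1: tp=1, fp=1, tn=1, fn=0 and sup=1 (one positive sample in target)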
@@ -145,12 +147,12 @@ def stat_scores_multiple_classes(
     Args:
         pred: prediction tensor
         target: target tensor
         class_index: class to calculate over
         num_classes: number of classes if known
         argmax_dim: if pred is a tensor of probabilities, this indicates the
             axis the argmax transformation will be applied over
 
     Return:
-        True Positive, False Positive, True Negative, False Negative
+        True Positive, False Positive, True Negative, False Negative, Support
 
     Example:
@@ -199,11 +201,11 @@ def accuracy(
         target: ground truth labels
         num_classes: number of classes
         reduction: a method for reducing accuracies over labels (default: takes the mean)
-        Available reduction methods:
+            Available reduction methods:
 
-        - elementwise_mean: takes the mean
-        - none: pass array
-        - sum: add elements
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
 
     Return:
         A Tensor with the classification score.
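The reduction options read as follows in practice (a sketch; exact reduced values depend on the per-class bookkeeping)::

    import torch
    from pytorch_lightning.metrics.functional import accuracy

    pred = torch.tensor([0, 1, 2, 2])
    target = torch.tensor([0, 1, 2, 2])
    accuracy(pred, target)                    # scalar: elementwise_mean is the default
    accuracy(pred, target, reduction='none')  # per-label scores, left unreduced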
@@ -282,11 +284,11 @@ def precision_recall(
         target: ground-truth labels
         num_classes: number of classes
         reduction: method for reducing precision-recall values (default: takes the mean)
-        Available reduction methods:
+            Available reduction methods:
 
-        - elementwise_mean: takes the mean
-        - none: pass array
-        - sum: add elements
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
 
     Return:
         Tensor with precision and recall
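Usage sketch for the values returned above (assuming the function hands back precision and recall as a pair, per the Return section)::

    import torch
    from pytorch_lightning.metrics.functional import precision_recall

    pred = torch.tensor([0, 1, 1])
    target = torch.tensor([0, 1, 0])
    prec, rec = precision_recall(pred, target)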
@@ -331,11 +333,11 @@ def precision(
         target: ground-truth labels
         num_classes: number of classes
         reduction: method for reducing precision values (default: takes the mean)
-        Available reduction methods:
+            Available reduction methods:
 
-        - elementwise_mean: takes the mean
-        - none: pass array
-        - sum: add elements
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
 
     Return:
         Tensor with precision.
@@ -366,11 +368,11 @@ def recall(
         target: ground-truth labels
         num_classes: number of classes
         reduction: method for reducing recall values (default: takes the mean)
-        Available reduction methods:
+            Available reduction methods:
 
-        - elementwise_mean: takes the mean
-        - none: pass array
-        - sum: add elements
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
 
     Return:
         Tensor with recall.
@@ -407,11 +409,11 @@ def fbeta_score(
             beta -> inf: only recall
         num_classes: number of classes
         reduction: method for reducing F-score (default: takes the mean)
-        Available reduction methods:
+            Available reduction methods:
 
-        - elementwise_mean: takes the mean
-        - none: pass array
-        - sum: add elements.
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements.
 
     Return:
         Tensor with the value of F-score. It is a value between 0-1.
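The beta notes above correspond to the usual definition F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall); a quick call sketch::

    import torch
    from pytorch_lightning.metrics.functional import fbeta_score

    pred = torch.tensor([0, 1, 1])
    target = torch.tensor([0, 1, 0])
    fbeta_score(pred, target, beta=0.5)  # beta < 1 weights precision higher than recall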
@@ -452,11 +454,11 @@ def f1_score(
         target: ground-truth labels
         num_classes: number of classes
         reduction: method for reducing F1-score (default: takes the mean)
-        Available reduction methods:
+            Available reduction methods:
 
-        - elementwise_mean: takes the mean
-        - none: pass array
-        - sum: add elements.
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements.
 
     Return:
         Tensor containing F1-score
@@ -510,7 +512,6 @@ def _binary_clf_curve(
         # express fps as a cumsum to ensure fps is increasing even in
         # the presence of floating point errors
         fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
-
     else:
         fps = 1 + threshold_idxs - tps
 
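Why the cumsum: summed this way, the false-positive counts are non-decreasing by construction, even with fractional sample weights. A standalone illustration::

    import torch

    target = torch.tensor([1., 0., 1., 0.])
    weight = torch.tensor([0.5, 1.0, 0.5, 1.0])
    fps = torch.cumsum((1 - target) * weight, dim=0)
    # tensor([0., 1., 1., 2.]) -- monotone, as the comment in the hunk requires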
@@ -530,7 +531,7 @@ def roc(
         pred: estimated probabilities
         target: ground-truth labels
         sample_weight: sample weights
-        pos_label: the label for the positive class (default: 1)
+        pos_label: the label for the positive class
 
     Return:
         false-positive rate (fpr), true-positive rate (tpr), thresholds
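Call sketch for the three return values (import path assumed)::

    import torch
    from pytorch_lightning.metrics.functional import roc

    pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
    target = torch.tensor([0, 0, 1, 1])
    fpr, tpr, thresholds = roc(pred, target)  # one entry per distinct threshold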
@@ -628,7 +629,7 @@ def precision_recall_curve(
         pred: estimated probabilities
         target: ground-truth labels
         sample_weight: sample weights
-        pos_label: the label for the positive class (default: 1.)
+        pos_label: the label for the positive class
 
     Return:
         precision, recall, thresholds
@@ -722,14 +723,18 @@ def multiclass_precision_recall_curve(
     return tuple(class_pr_vals)
 
 
-def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True) -> torch.Tensor:
+def auc(
+        x: torch.Tensor,
+        y: torch.Tensor,
+        reorder: bool = True
+) -> torch.Tensor:
     """
     Computes Area Under the Curve (AUC) using the trapezoidal rule
 
     Args:
         x: x-coordinates
         y: y-coordinates
-        reorder: reorder coordinates, so they are increasing.
+        reorder: reorder coordinates, so they are increasing
 
     Return:
         Tensor containing AUC score (float)
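The trapezoidal rule is easy to verify on a straight line: the area under y = x on [0, 1] is 0.5::

    import torch
    from pytorch_lightning.metrics.functional import auc

    x = torch.tensor([0.0, 0.5, 1.0])
    y = torch.tensor([0.0, 0.5, 1.0])
    auc(x, y)  # tensor(0.5000)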
@@ -800,7 +805,10 @@ def auroc(
         pred: estimated probabilities
         target: ground-truth labels
         sample_weight: sample weights
-        pos_label: the label for the positive class (default: 1.)
+        pos_label: the label for the positive class
 
     Return:
         Tensor containing ROCAUC score
+
+    Example:
+
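A quick call with the classic four-point example; three of the four positive/negative pairs are ranked correctly, so the score is 0.75::

    import torch
    from pytorch_lightning.metrics.functional import auroc

    pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
    target = torch.tensor([0, 0, 1, 1])
    auroc(pred, target)  # tensor(0.7500)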
@@ -824,12 +832,16 @@ def average_precision(
     pos_label: int = 1.,
 ) -> torch.Tensor:
     """
     Compute average precision from prediction scores
 
     Args:
         pred: estimated probabilities
         target: ground-truth labels
         sample_weight: sample weights
-        pos_label: the label for the positive class (default: 1.)
+        pos_label: the label for the positive class
 
     Return:
+        Tensor containing average precision score
+
+    Example:
+
@@ -856,11 +868,13 @@ def dice_score(
     reduction: str = 'elementwise_mean',
 ) -> torch.Tensor:
     """
     Compute dice score from prediction scores
 
     Args:
         pred: estimated probabilities
         target: ground-truth labels
         bg: whether to also compute dice for the background
-        nan_score: score to return, if a NaN occurs during computation (denom zero)
+        nan_score: score to return, if a NaN occurs during computation
         no_fg_score: score to return, if no foreground pixel was found in target
+        reduction: a method for reducing accuracies over labels (default: takes the mean)
+            Available reduction methods:
@@ -869,6 +883,9 @@ def dice_score(
+            - none: pass array
+            - sum: add elements
+
     Return:
         Tensor containing dice score
 
     Example:
 
         >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
@@ -880,10 +897,10 @@ def dice_score(
         tensor(0.2500)
 
     """
-    n_classes = pred.shape[1]
+    num_classes = pred.shape[1]
     bg = (1 - int(bool(bg)))
-    scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)
-    for i in range(bg, n_classes):
+    scores = torch.zeros(num_classes - bg, device=pred.device, dtype=torch.float32)
+    for i in range(bg, num_classes):
         if not (target == i).any():
             # no foreground class
             scores[i - bg] += no_fg_score
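Dice per class is 2*TP / (2*TP + FP + FN); with perfect predictions every foreground class scores 1.0::

    import torch
    from pytorch_lightning.metrics.functional import dice_score

    pred = torch.tensor([[0.9, 0.1], [0.2, 0.8]])  # class probabilities per sample
    target = torch.tensor([0, 1])
    dice_score(pred, target)  # tensor(1.), background (class 0) skipped by default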
@@ -903,31 +920,32 @@ def dice_score(
     return reduce(scores, reduction=reduction)
 
 
-def iou(pred: torch.Tensor, target: torch.Tensor,
-        num_classes: Optional[int] = None, remove_bg: bool = False,
-        reduction: str = 'elementwise_mean'):
+def iou(
+        pred: torch.Tensor,
+        target: torch.Tensor,
+        num_classes: Optional[int] = None,
+        remove_bg: bool = False,
+        reduction: str = 'elementwise_mean'
+) -> torch.Tensor:
     """
     Intersection over union, or Jaccard index calculation.
 
     Args:
         pred: Tensor containing predictions
-
         target: Tensor containing targets
-
         num_classes: Optionally specify the number of classes
-
         remove_bg: Flag to state whether a background class has been included
             within input parameters. If true, will remove background class. If
-            false, return IoU over all classes.
+            false, return IoU over all classes
             Assumes that background is '0' class in input tensor
 
         reduction: a method for reducing IoU over labels (default: takes the mean)
             Available reduction methods:
 
             - elementwise_mean: takes the mean
             - none: pass array
             - sum: add elements
 
-    Returns:
+    Return:
         IoU score : Tensor containing single value if reduction is
         'elementwise_mean', or number of classes if reduction is 'none'
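Usage sketch with dense integer inputs, matching the Args description above::

    import torch
    from pytorch_lightning.metrics.functional import iou

    pred = torch.tensor([0, 0, 1, 1])
    target = torch.tensor([0, 0, 1, 1])
    iou(pred, target)  # tensor(1.), intersection equals union for every class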
@@ -278,7 +278,7 @@ class ConfusionMatrix(SklearnMetric):
             y_true: Ground truth (correct) target values.
 
         Return:
-            Confusion matrix (array of shape [n_classes, n_classes])
+            Confusion matrix (array of shape [num_classes, num_classes])
 
         """
         return super().forward(y_pred=y_pred, y_true=y_true)
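Sketch of the call shape (constructor arguments assumed default)::

    import torch
    from pytorch_lightning.metrics.sklearns import ConfusionMatrix

    metric = ConfusionMatrix()
    metric(torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1]))  # [num_classes, num_classes] matrix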
@@ -103,7 +103,7 @@ def test_onehot():
     assert test_tensor.shape == (2, 5)
     assert expected.shape == (2, 10, 5)
 
-    onehot_classes = to_onehot(test_tensor, n_classes=10)
+    onehot_classes = to_onehot(test_tensor, num_classes=10)
     onehot_no_classes = to_onehot(test_tensor)
 
     assert torch.allclose(onehot_classes, onehot_no_classes)