lightning/pytorch_lightning/metrics/classification.py

819 lines
26 KiB
Python

from typing import Any, Optional, Sequence, Tuple
import torch
from pytorch_lightning.metrics.functional.classification import (
accuracy,
confusion_matrix,
precision_recall_curve,
precision,
recall,
average_precision,
auroc,
fbeta_score,
f1_score,
roc,
multiclass_roc,
multiclass_precision_recall_curve,
dice_score,
iou,
)
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric
class Accuracy(TensorMetric):
    """
    Computes the accuracy classification score.

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = Accuracy()
        >>> metric(pred, target)
        tensor(0.7500)

    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing accuracies over labels (default: takes the mean)
                Available reduction methods:

                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements

            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='accuracy', reduce_group=reduce_group, reduce_op=reduce_op)
        # Kept on the instance and handed to the functional ``accuracy`` on each call.
        self.reduction = reduction
        self.num_classes = num_classes

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the classification score.
        """
        score = accuracy(
            pred=pred,
            target=target,
            num_classes=self.num_classes,
            reduction=self.reduction,
        )
        return score
class ConfusionMatrix(TensorMetric):
    """
    Computes the confusion matrix C where each entry C_{i,j} is the number of observations
    in group i that were predicted in group j.

    Example:

        >>> pred = torch.tensor([0, 1, 2, 2])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = ConfusionMatrix()
        >>> metric(pred, target)
        tensor([[1., 0., 0.],
                [0., 1., 0.],
                [0., 0., 2.]])

    """

    def __init__(
            self,
            normalize: bool = False,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            normalize: whether to compute a normalized confusion matrix
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='confusion_matrix', reduce_group=reduce_group, reduce_op=reduce_op)
        # Forwarded unchanged to the functional ``confusion_matrix``.
        self.normalize = normalize

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the confusion matrix.
        """
        matrix = confusion_matrix(
            pred=pred,
            target=target,
            normalize=self.normalize,
        )
        return matrix
class PrecisionRecall(TensorCollectionMetric):
    """
    Computes the precision recall curve

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = PrecisionRecall()
        >>> prec, recall, thr = metric(pred, target)
        >>> prec
        tensor([0.3333, 0.0000, 0.0000, 1.0000])
        >>> recall
        tensor([1., 0., 0., 0.])
        >>> thr
        tensor([1., 2., 3.])

    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(
            name='precision_recall_curve',
            reduce_group=reduce_group,
            reduce_op=reduce_op,
        )
        # Handed to the functional ``precision_recall_curve`` on every call.
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: groundtruth labels
            sample_weight: the weights per sample

        Return:
            - precision values
            - recall values
            - threshold values
        """
        curve = precision_recall_curve(
            pred=pred,
            target=target,
            sample_weight=sample_weight,
            pos_label=self.pos_label,
        )
        return curve
class Precision(TensorMetric):
    """
    Computes the precision score

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = Precision(num_classes=4)
        >>> metric(pred, target)
        tensor(0.7500)

    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing the scores over labels (default: takes the mean)
                Available reduction methods:

                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements

            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='precision', reduce_group=reduce_group, reduce_op=reduce_op)
        # Both settings are passed straight through to the functional ``precision``.
        self.reduction = reduction
        self.num_classes = num_classes

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the classification score.
        """
        score = precision(
            pred=pred,
            target=target,
            num_classes=self.num_classes,
            reduction=self.reduction,
        )
        return score
class Recall(TensorMetric):
    """
    Computes the recall score

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = Recall()
        >>> metric(pred, target)
        tensor(0.6250)

    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing the scores over labels (default: takes the mean)
                Available reduction methods:

                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements

            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='recall', reduce_group=reduce_group, reduce_op=reduce_op)
        # Both settings are passed straight through to the functional ``recall``.
        self.reduction = reduction
        self.num_classes = num_classes

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the classification score.
        """
        score = recall(
            pred=pred,
            target=target,
            num_classes=self.num_classes,
            reduction=self.reduction,
        )
        return score
class AveragePrecision(TensorMetric):
    """
    Computes the average precision score

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = AveragePrecision()
        >>> metric(pred, target)
        tensor(0.3333)

    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='AP', reduce_group=reduce_group, reduce_op=reduce_op)
        # Handed to the functional ``average_precision`` on every call.
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: groundtruth labels
            sample_weight: the weights per sample

        Return:
            torch.Tensor: classification score
        """
        score = average_precision(
            pred=pred,
            target=target,
            sample_weight=sample_weight,
            pos_label=self.pos_label,
        )
        return score
class AUROC(TensorMetric):
    """
    Computes the area under curve (AUC) of the receiver operator characteristic (ROC)

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = AUROC()
        >>> metric(pred, target)
        tensor(0.3333)

    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='auroc', reduce_group=reduce_group, reduce_op=reduce_op)
        # Handed to the functional ``auroc`` on every call.
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: groundtruth labels
            sample_weight: the weights per sample

        Return:
            torch.Tensor: classification score
        """
        score = auroc(
            pred=pred,
            target=target,
            sample_weight=sample_weight,
            pos_label=self.pos_label,
        )
        return score
class FBeta(TensorMetric):
    """
    Computes the FBeta Score, which is the weighted harmonic mean of precision and recall.
    It ranges between 1 and 0, where 1 is perfect and the worst value is 0.

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = FBeta(0.25)
        >>> metric(pred, target)
        tensor(0.7361)

    """

    def __init__(
            self,
            beta: float,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            beta: determines the weight of recall in the combined score.
            num_classes: number of classes
            reduction: a method for reducing the scores over labels (default: takes the mean)
                Available reduction methods:

                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements

            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for DDP reduction
        """
        super().__init__(name='fbeta', reduce_group=reduce_group, reduce_op=reduce_op)
        # All three settings are passed straight through to the functional ``fbeta_score``.
        self.reduction = reduction
        self.num_classes = num_classes
        self.beta = beta

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: groundtruth labels

        Return:
            torch.Tensor: classification score
        """
        score = fbeta_score(
            pred=pred,
            target=target,
            beta=self.beta,
            num_classes=self.num_classes,
            reduction=self.reduction,
        )
        return score
class F1(TensorMetric):
    """
    Computes the F1 score, which is the harmonic mean of the precision and recall.
    It ranges between 1 and 0, where 1 is perfect and the worst value is 0.

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = F1()
        >>> metric(pred, target)
        tensor(0.6667)

    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduction: a method for reducing the scores over labels (default: takes the mean)
                Available reduction methods:

                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements

            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='f1', reduce_group=reduce_group, reduce_op=reduce_op)
        # Both settings are passed straight through to the functional ``f1_score``.
        self.reduction = reduction
        self.num_classes = num_classes

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: groundtruth labels

        Return:
            torch.Tensor: classification score
        """
        score = f1_score(
            pred=pred,
            target=target,
            num_classes=self.num_classes,
            reduction=self.reduction,
        )
        return score
class ROC(TensorCollectionMetric):
    """
    Computes the Receiver Operator Characteristic (ROC)

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> metric = ROC()
        >>> fps, tps, thresholds = metric(pred, target)
        >>> fps
        tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
        >>> tps
        tensor([0., 0., 0., 1., 1.])
        >>> thresholds
        tensor([4., 3., 2., 1., 0.])

    """

    def __init__(
            self,
            pos_label: int = 1,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            pos_label: positive label indicator
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='roc', reduce_group=reduce_group, reduce_op=reduce_op)
        # Handed to the functional ``roc`` on every call.
        self.pos_label = pos_label

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: groundtruth labels
            sample_weight: the weights per sample

        Return:
            - false positive rate
            - true positive rate
            - thresholds
        """
        curve = roc(
            pred=pred,
            target=target,
            sample_weight=sample_weight,
            pos_label=self.pos_label,
        )
        return curve
class MulticlassROC(TensorCollectionMetric):
    """
    Computes the multiclass ROC

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> metric = MulticlassROC()
        >>> classes_roc = metric(pred, target)
        >>> metric(pred, target)   # doctest: +NORMALIZE_WHITESPACE
        ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))

    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        # NOTE: this metric takes no ``reduction`` argument; the result is
        # returned per class rather than reduced over labels.
        super().__init__(name='multiclass_roc',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        """
        Actual metric computation

        Args:
            pred: predicted probability for each label
            target: groundtruth labels
            sample_weight: Weights for each sample defining the sample's impact on the score

        Return:
            tuple: A tuple consisting of one tuple per class, holding false positive rate,
            true positive rate and thresholds
        """
        # The return annotation is a variable-length tuple (``...``): its length
        # follows the number of classes, one (fpr, tpr, thresholds) triple each.
        return multiclass_roc(pred=pred,
                              target=target,
                              sample_weight=sample_weight,
                              num_classes=self.num_classes)
class MulticlassPrecisionRecall(TensorCollectionMetric):
    """Computes the multiclass PR Curve

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> metric = MulticlassPrecisionRecall()
        >>> metric(pred, target)   # doctest: +NORMALIZE_WHITESPACE
        ((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
         (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
         (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])),
         (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])))

    """

    def __init__(
            self,
            num_classes: Optional[int] = None,
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            num_classes: number of classes
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        # NOTE: this metric takes no ``reduction`` argument; the result is
        # returned per class rather than reduced over labels.
        super().__init__(name='multiclass_precision_recall_curve',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.num_classes = num_classes

    def forward(
            self,
            pred: torch.Tensor,
            target: torch.Tensor,
            sample_weight: Optional[Sequence] = None,
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
        """
        Actual metric computation

        Args:
            pred: predicted probability for each label
            target: groundtruth labels
            sample_weight: Weights for each sample defining the sample's impact on the score

        Return:
            tuple: A tuple consisting of one tuple per class, holding precision, recall
            and thresholds
        """
        # The return annotation is a variable-length tuple (``...``): its length
        # follows the number of classes, one (precision, recall, thresholds)
        # triple each, as the class-level example shows.
        return multiclass_precision_recall_curve(pred=pred,
                                                 target=target,
                                                 sample_weight=sample_weight,
                                                 num_classes=self.num_classes)
class DiceCoefficient(TensorMetric):
    """
    Computes the dice coefficient

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> metric = DiceCoefficient()
        >>> metric(pred, target)
        tensor(0.3333)

    """

    def __init__(
            self,
            include_background: bool = False,
            nan_score: float = 0.0,
            no_fg_score: float = 0.0,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            include_background: whether to also compute dice for the background
            nan_score: score to return, if a NaN occurs during computation (denom zero)
            no_fg_score: score to return, if no foreground pixel was found in target
            reduction: a method for reducing the scores over labels (default: takes the mean)
                Available reduction methods:

                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements

            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        super().__init__(name='dice', reduce_group=reduce_group, reduce_op=reduce_op)
        # Every setting below is forwarded to the functional ``dice_score``;
        # ``include_background`` is passed there under the keyword ``bg``.
        self.reduction = reduction
        self.no_fg_score = no_fg_score
        self.nan_score = nan_score
        self.include_background = include_background

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted probability for each label
            target: groundtruth labels

        Return:
            torch.Tensor: the calculated dice coefficient
        """
        score = dice_score(
            pred=pred,
            target=target,
            bg=self.include_background,
            nan_score=self.nan_score,
            no_fg_score=self.no_fg_score,
            reduction=self.reduction,
        )
        return score
class IoU(TensorMetric):
    """
    Computes the intersection over union.

    Example:

        >>> pred = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        ...                      [0, 0, 1, 1, 1, 0, 0, 0],
        ...                      [0, 0, 0, 0, 0, 0, 0, 0]])
        >>> target = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        ...                        [0, 0, 0, 1, 1, 1, 0, 0],
        ...                        [0, 0, 0, 0, 0, 0, 0, 0]])
        >>> metric = IoU()
        >>> metric(pred, target)
        tensor(0.7045)

    """

    def __init__(
            self,
            remove_bg: bool = False,
            reduction: str = 'elementwise_mean',
            reduce_group: Any = None,
            reduce_op: Any = None,
    ):
        """
        Args:
            remove_bg: Flag to state whether a background class has been included
                within input parameters. If true, will remove background class. If
                false, return IoU over all classes.
                Assumes that background is '0' class in input tensor
            reduction: a method for reducing IoU over labels (default: takes the mean)
                Available reduction methods:

                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements

            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction
        """
        # ``reduce_group`` / ``reduce_op`` are appended after the pre-existing
        # parameters so positional callers such as ``IoU(True, 'sum')`` keep
        # working; they are forwarded the same way the other metrics in this
        # module forward them.
        super().__init__(name='iou',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)
        self.remove_bg = remove_bg
        self.reduction = reduction

    def forward(
            self,
            y_pred: torch.Tensor,
            y_true: torch.Tensor,
            sample_weight: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Actual metric calculation.

        Args:
            y_pred: predicted labels
            y_true: ground truth labels
            sample_weight: accepted for signature compatibility but currently
                unused; it is not passed on to the functional ``iou``

        Return:
            torch.Tensor: the calculated intersection over union
        """
        return iou(y_pred, y_true, remove_bg=self.remove_bg, reduction=self.reduction)