221 lines
8.7 KiB
Python
221 lines
8.7 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from typing import List, Optional, Sequence, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from pytorch_lightning.utilities import rank_zero_warn
|
|
|
|
|
|
def _binary_clf_curve(
|
|
preds: torch.Tensor,
|
|
target: torch.Tensor,
|
|
sample_weights: Optional[Sequence] = None,
|
|
pos_label: int = 1.,
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""
|
|
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
|
|
"""
|
|
if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
|
|
sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float)
|
|
|
|
# remove class dimension if necessary
|
|
if preds.ndim > target.ndim:
|
|
preds = preds[:, 0]
|
|
desc_score_indices = torch.argsort(preds, descending=True)
|
|
|
|
preds = preds[desc_score_indices]
|
|
target = target[desc_score_indices]
|
|
|
|
if sample_weights is not None:
|
|
weight = sample_weights[desc_score_indices]
|
|
else:
|
|
weight = 1.
|
|
|
|
# pred typically has many tied values. Here we extract
|
|
# the indices associated with the distinct values. We also
|
|
# concatenate a value for the end of the curve.
|
|
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
|
|
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
|
|
target = (target == pos_label).to(torch.long)
|
|
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
|
|
|
|
if sample_weights is not None:
|
|
# express fps as a cumsum to ensure fps is increasing even in
|
|
# the presence of floating point errors
|
|
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
|
|
else:
|
|
fps = 1 + threshold_idxs - tps
|
|
|
|
return fps, tps, preds[threshold_idxs]
|
|
|
|
|
|
def _precision_recall_curve_update(
|
|
preds: torch.Tensor,
|
|
target: torch.Tensor,
|
|
num_classes: Optional[int] = None,
|
|
pos_label: Optional[int] = None,
|
|
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
|
|
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
|
|
raise ValueError(
|
|
"preds and target must have same number of dimensions, or one additional dimension for preds"
|
|
)
|
|
# single class evaluation
|
|
if len(preds.shape) == len(target.shape):
|
|
num_classes = 1
|
|
if pos_label is None:
|
|
rank_zero_warn('`pos_label` automatically set 1.')
|
|
pos_label = 1
|
|
preds = preds.flatten()
|
|
target = target.flatten()
|
|
|
|
# multi class evaluation
|
|
if len(preds.shape) == len(target.shape) + 1:
|
|
if pos_label is not None:
|
|
rank_zero_warn('Argument `pos_label` should be `None` when running'
|
|
f'multiclass precision recall curve. Got {pos_label}')
|
|
if num_classes != preds.shape[1]:
|
|
raise ValueError(f'Argument `num_classes` was set to {num_classes} in'
|
|
f'metric `precision_recall_curve` but detected {preds.shape[1]}'
|
|
'number of classes from predictions')
|
|
preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
|
|
target = target.flatten()
|
|
|
|
return preds, target, num_classes, pos_label
|
|
|
|
|
|
def _precision_recall_curve_compute(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
           Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """Build precision-recall curve(s) from already formatted inputs.

    Args:
        preds: formatted predictions. Used directly when ``num_classes == 1``;
            otherwise column ``preds[:, c]`` is scored for each class ``c``.
        target: formatted ground-truth labels.
        num_classes: number of classes; ``1`` selects the binary computation.
        pos_label: positive label for the binary computation. In the multiclass
            path each class index is used as the positive label in turn.
        sample_weights: optional per-sample weights, forwarded unchanged.

    Returns:
        ``(precision, recall, thresholds)``. In the binary case these are
        tensors ordered so recall is decreasing, truncated at the first
        threshold reaching full recall, with a final precision of 1 and recall
        of 0 appended (so ``thresholds`` is one element shorter). Otherwise
        each is a list holding one such tensor per class.
    """
    if num_classes != 1:
        # One-vs-rest: score every class column as its own binary problem,
        # using the class index as the positive label.
        curves = [
            precision_recall_curve(
                preds=preds[:, cls_idx],
                target=target,
                num_classes=1,
                pos_label=cls_idx,
                sample_weights=sample_weights,
            )
            for cls_idx in range(num_classes)
        ]
        return (
            [curve[0] for curve in curves],
            [curve[1] for curve in curves],
            [curve[2] for curve in curves],
        )

    fps, tps, thresholds = _binary_clf_curve(
        preds=preds,
        target=target,
        sample_weights=sample_weights,
        pos_label=pos_label,
    )
    total_tp = tps[-1]
    precision = tps / (tps + fps)
    recall = tps / total_tp

    # Keep everything up to the first threshold at which full recall is
    # reached, then flip so recall is decreasing. A flip is used because
    # negative-step slicing is not supported on tensors.
    cutoff = torch.where(tps == total_tp)[0][0].item() + 1

    precision = torch.cat([precision[:cutoff].flip(0), precision.new_ones(1)])
    recall = torch.cat([recall[:cutoff].flip(0), recall.new_zeros(1)])
    thresholds = thresholds[:cutoff].flip(0)

    return precision, recall, thresholds
|
|
|
|
|
|
def precision_recall_curve(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
           Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """
    Computes precision-recall pairs for different thresholds.

    Args:
        preds: predictions from the model. Either the same number of dimensions
            as ``target`` (binary case) or one extra dimension at position 1
            holding a score per class (multiclass case).
        target: ground truth labels
        num_classes: integer with number of classes. Not necessary to provide
            for binary problems; for multiclass problems it must equal
            ``preds.shape[1]``.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0, num_classes-1]
        sample_weights: sample weights for each data point

    Returns: 3-element tuple containing

        precision:
            tensor where element i is the precision of predictions with
            score >= thresholds[i] and the last element is 1.
            If multiclass, this is a list of such tensors, one for each class.
        recall:
            tensor where element i is the recall of predictions with
            score >= thresholds[i] and the last element is 0.
            If multiclass, this is a list of such tensors, one for each class.
        thresholds:
            Thresholds used for computing precision/recall scores

    Example (binary case):

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
        >>> precision
        tensor([0.6667, 0.5000, 0.0000, 1.0000])
        >>> recall
        tensor([1.0000, 0.5000, 0.0000, 0.0000])
        >>> thresholds
        tensor([1, 2, 3])

    Example (multiclass case):

        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
        >>> precision   # doctest: +NORMALIZE_WHITESPACE
        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
        >>> recall
        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
        >>> thresholds
        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]

    """
    # Validate/flatten the inputs first, then build the curve(s) from them.
    formatted_preds, formatted_target, num_classes, pos_label = _precision_recall_curve_update(
        preds, target, num_classes=num_classes, pos_label=pos_label,
    )
    return _precision_recall_curve_compute(
        formatted_preds,
        formatted_target,
        num_classes,
        pos_label,
        sample_weights=sample_weights,
    )
|