# lightning/pytorch_lightning/metrics/functional/precision_recall_curve.py
#
# 221 lines
# 8.7 KiB
# Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_warn
def _binary_clf_curve(
preds: torch.Tensor,
target: torch.Tensor,
sample_weights: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
"""
if sample_weights is not None and not isinstance(sample_weights, torch.Tensor):
sample_weights = torch.tensor(sample_weights, device=preds.device, dtype=torch.float)
# remove class dimension if necessary
if preds.ndim > target.ndim:
preds = preds[:, 0]
desc_score_indices = torch.argsort(preds, descending=True)
preds = preds[desc_score_indices]
target = target[desc_score_indices]
if sample_weights is not None:
weight = sample_weights[desc_score_indices]
else:
weight = 1.
# pred typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = torch.where(preds[1:] - preds[:-1])[0]
threshold_idxs = F.pad(distinct_value_indices, (0, 1), value=target.size(0) - 1)
target = (target == pos_label).to(torch.long)
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
if sample_weights is not None:
# express fps as a cumsum to ensure fps is increasing even in
# the presence of floating point errors
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
else:
fps = 1 + threshold_idxs - tps
return fps, tps, preds[threshold_idxs]
def _precision_recall_curve_update(
preds: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
pos_label: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int, int]:
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError(
"preds and target must have same number of dimensions, or one additional dimension for preds"
)
# single class evaluation
if len(preds.shape) == len(target.shape):
num_classes = 1
if pos_label is None:
rank_zero_warn('`pos_label` automatically set 1.')
pos_label = 1
preds = preds.flatten()
target = target.flatten()
# multi class evaluation
if len(preds.shape) == len(target.shape) + 1:
if pos_label is not None:
rank_zero_warn('Argument `pos_label` should be `None` when running'
f'multiclass precision recall curve. Got {pos_label}')
if num_classes != preds.shape[1]:
raise ValueError(f'Argument `num_classes` was set to {num_classes} in'
f'metric `precision_recall_curve` but detected {preds.shape[1]}'
'number of classes from predictions')
preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1)
target = target.flatten()
return preds, target, num_classes, pos_label
def _precision_recall_curve_compute(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: int,
    pos_label: int,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
           Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """Compute precision, recall and thresholds from already-prepared inputs.

    Inputs are expected in the layout produced by ``_precision_recall_curve_update``.

    Args:
        preds: 1-D scores when ``num_classes == 1``; otherwise a 2-D tensor whose
            column ``c`` holds the scores for class ``c``.
        target: 1-D ground-truth labels.
        num_classes: number of classes; ``1`` selects the binary path.
        pos_label: positive label for the binary path. In the multiclass path
            each class index ``c`` is used as the positive label for column ``c``.
        sample_weights: optional per-sample weights.

    Returns:
        ``(precision, recall, thresholds)``. These are tensors in the binary case,
        or lists with one tensor per class in the multiclass case. ``precision``
        ends with 1, ``recall`` ends with 0, and ``thresholds`` is increasing.
        If there are no positive samples, recall is ``nan`` (0 / 0) before the
        trailing 0.
    """
    if num_classes == 1:
        fps, tps, thresholds = _binary_clf_curve(
            preds=preds,
            target=target,
            sample_weights=sample_weights,
            pos_label=pos_label,
        )
        precision = tps / (tps + fps)
        recall = tps / tps[-1]

        # stop when full recall attained
        # and reverse the outputs so recall is decreasing
        last_ind = torch.where(tps == tps[-1])[0][0]
        sl = slice(0, last_ind.item() + 1)

        # Reverse with torch.flip (which returns a copy): slicing with a
        # negative step is not supported by PyTorch.
        precision = torch.cat([torch.flip(precision[sl], dims=(0,)),
                               torch.ones(1, dtype=precision.dtype, device=precision.device)])
        recall = torch.cat([torch.flip(recall[sl], dims=(0,)),
                            torch.zeros(1, dtype=recall.dtype, device=recall.device)])
        thresholds = torch.flip(thresholds[sl], dims=(0,))
        return precision, recall, thresholds

    # Multiclass: one binary curve per class, treating class ``c`` as positive.
    # Inputs are already validated and reshaped, so recurse into this helper
    # directly rather than through the public wrapper.
    precision, recall, thresholds = [], [], []
    for c in range(num_classes):
        prec_c, rec_c, thr_c = _precision_recall_curve_compute(
            preds=preds[:, c],
            target=target,
            num_classes=1,
            pos_label=c,
            sample_weights=sample_weights,
        )
        precision.append(prec_c)
        recall.append(rec_c)
        thresholds.append(thr_c)
    return precision, recall, thresholds
def precision_recall_curve(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    sample_weights: Optional[Sequence] = None,
) -> Union[Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
           Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]]:
    """
    Computes precision-recall pairs for different thresholds.

    Args:
        preds: predicted scores. Either the same number of dimensions as
            ``target`` (binary case) or one extra class dimension at position 1
            (multiclass case).
        target: ground truth labels.
        num_classes: integer with number of classes. Not necessary to provide
            for binary problems. For multiclass problems it must be provided
            and must match ``preds.shape[1]``.
        pos_label: integer determining the positive class. Default is ``None``,
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set, as it is iterated over the
            range ``[0, num_classes - 1]``.
        sample_weights: sample weights for each data point

    Returns: 3-element tuple containing
        precision:
            tensor where element i is the precision of predictions with
            score >= thresholds[i] and the last element is 1.
            If multiclass, this is a list of such tensors, one for each class.
        recall:
            tensor where element i is the recall of predictions with
            score >= thresholds[i] and the last element is 0.
            If multiclass, this is a list of such tensors, one for each class.
        thresholds:
            Thresholds used for computing precision/recall scores, in
            increasing order. If multiclass, this is a list of such tensors,
            one for each class.

    Raises:
        ValueError: if ``preds`` and ``target`` have incompatible numbers of
            dimensions, or (multiclass) if ``num_classes`` does not match
            ``preds.shape[1]``.

    Example (binary case):

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 1, 0])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, pos_label=1)
        >>> precision
        tensor([0.6667, 0.5000, 0.0000, 1.0000])
        >>> recall
        tensor([1.0000, 0.5000, 0.0000, 0.0000])
        >>> thresholds
        tensor([1, 2, 3])

    Example (multiclass case):

        >>> pred = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.75, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.75, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.75, 0.05]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
        >>> precision   # doctest: +NORMALIZE_WHITESPACE
        [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
         tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
        >>> recall
        [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
        >>> thresholds
        [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
    """
    preds, target, num_classes, pos_label = _precision_recall_curve_update(
        preds, target, num_classes, pos_label
    )
    return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights)