180 lines
7.3 KiB
Python
180 lines
7.3 KiB
Python
|
# Copyright The PyTorch Lightning team.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
from distutils.version import LooseVersion
|
||
|
from typing import Optional, Sequence, Tuple
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
|
||
|
from pytorch_lightning.metrics.functional.auc import auc
|
||
|
from pytorch_lightning.metrics.functional.roc import roc
|
||
|
from pytorch_lightning.utilities import LightningEnum
|
||
|
|
||
|
|
||
|
class AverageMethods(LightningEnum):
    """Reduction strategies accepted by the ``average`` argument of AUROC."""

    # plain mean of the per-class AUC scores
    MACRO = "macro"
    # per-class AUC scores weighted by each class's support
    WEIGHTED = "weighted"
    # no reduction: the per-class scores are returned as-is
    NONE = None
|
||
|
|
||
|
|
||
|
def _auroc_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, str]:
    """Validate AUROC inputs and fold any extra dimensions into the sample axis.

    Args:
        preds: predicted probabilities
        target: ground truth labels

    Returns:
        ``(preds, target, mode)`` where ``mode`` is the input-mode string reported by
        ``_input_format_classification``. For multi-dimensional multi-class inputs,
        and for multi-label inputs with more than two dimensions, ``preds`` is reshaped
        so that the class axis (originally dim 1) becomes the last axis of a 2D tensor.
    """
    # use _input_format_classification for validating the input and get the mode of data
    _, _, mode = _input_format_classification(preds, target)

    # NOTE(review): the helper is believed to report this case as 'multi-dim multi-class'
    # (hyphenated like 'multi-label'); the legacy spelling is accepted too so that either
    # form triggers the flattening.
    if mode in ('multi-dim multi-class', 'multi class multi dim'):
        n_classes = preds.shape[1]
        # (N, C, ...) -> (N * ..., C); target is flattened to match the new sample axis
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.flatten()
    elif mode == 'multi-label' and preds.ndim > 2:
        n_classes = preds.shape[1]
        # multi-label targets carry the class axis too, so both are reshaped alike
        preds = preds.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)
        target = target.transpose(0, 1).reshape(n_classes, -1).transpose(0, 1)

    return preds, target, mode
|
||
|
|
||
|
|
||
|
def _auroc_compute(
|
||
|
preds: torch.Tensor,
|
||
|
target: torch.Tensor,
|
||
|
mode: str,
|
||
|
num_classes: Optional[int] = None,
|
||
|
pos_label: Optional[int] = None,
|
||
|
average: Optional[str] = 'macro',
|
||
|
max_fpr: Optional[float] = None,
|
||
|
sample_weights: Optional[Sequence] = None,
|
||
|
) -> torch.Tensor:
|
||
|
# binary mode override num_classes
|
||
|
if mode == 'binary':
|
||
|
num_classes = 1
|
||
|
|
||
|
# check max_fpr parameter
|
||
|
if max_fpr is not None:
|
||
|
if (not isinstance(max_fpr, float) and 0 < max_fpr <= 1):
|
||
|
raise ValueError(f"`max_fpr` should be a float in range (0, 1], got: {max_fpr}")
|
||
|
|
||
|
if LooseVersion(torch.__version__) < LooseVersion('1.6.0'):
|
||
|
raise RuntimeError("`max_fpr` argument requires `torch.bucketize` which"
|
||
|
" is not available below PyTorch version 1.6")
|
||
|
|
||
|
# max_fpr parameter is only support for binary
|
||
|
if mode != 'binary':
|
||
|
raise ValueError(f"Partial AUC computation not available in "
|
||
|
f"multilabel/multiclass setting, 'max_fpr' must be"
|
||
|
f" set to `None`, received `{max_fpr}`.")
|
||
|
|
||
|
# calculate fpr, tpr
|
||
|
if mode == 'multi-label':
|
||
|
# for multilabel we iteratively evaluate roc in a binary fashion
|
||
|
output = [roc(preds[:, i], target[:, i], num_classes=1, pos_label=1, sample_weights=sample_weights)
|
||
|
for i in range(num_classes)]
|
||
|
fpr = [o[0] for o in output]
|
||
|
tpr = [o[1] for o in output]
|
||
|
else:
|
||
|
fpr, tpr, _ = roc(preds, target, num_classes, pos_label, sample_weights)
|
||
|
|
||
|
# calculate standard roc auc score
|
||
|
if max_fpr is None or max_fpr == 1:
|
||
|
if num_classes != 1:
|
||
|
# calculate auc scores per class
|
||
|
auc_scores = [auc(x, y) for x, y in zip(fpr, tpr)]
|
||
|
|
||
|
# calculate average
|
||
|
if average == AverageMethods.NONE:
|
||
|
return auc_scores
|
||
|
elif average == AverageMethods.MACRO:
|
||
|
return torch.mean(torch.stack(auc_scores))
|
||
|
elif average == AverageMethods.WEIGHTED:
|
||
|
if mode == 'multi-label':
|
||
|
support = torch.sum(target, dim=0)
|
||
|
else:
|
||
|
support = torch.bincount(target.flatten(), minlength=num_classes)
|
||
|
return torch.sum(torch.stack(auc_scores) * support / support.sum())
|
||
|
|
||
|
allowed_average = [e.value for e in AverageMethods]
|
||
|
raise ValueError(f"Argument `average` expected to be one of the following:"
|
||
|
f" {allowed_average} but got {average}")
|
||
|
|
||
|
return auc(fpr, tpr)
|
||
|
|
||
|
max_fpr = torch.tensor(max_fpr, device=fpr.device)
|
||
|
# Add a single point at max_fpr and interpolate its tpr value
|
||
|
stop = torch.bucketize(max_fpr, fpr, out_int32=True, right=True)
|
||
|
weight = (max_fpr - fpr[stop - 1]) / (fpr[stop] - fpr[stop - 1])
|
||
|
interp_tpr = torch.lerp(tpr[stop - 1], tpr[stop], weight)
|
||
|
tpr = torch.cat([tpr[:stop], interp_tpr.view(1)])
|
||
|
fpr = torch.cat([fpr[:stop], max_fpr.view(1)])
|
||
|
|
||
|
# Compute partial AUC
|
||
|
partial_auc = auc(fpr, tpr)
|
||
|
|
||
|
# McClish correction: standardize result to be 0.5 if non-discriminant
|
||
|
# and 1 if maximal
|
||
|
min_area = 0.5 * max_fpr ** 2
|
||
|
max_area = max_fpr
|
||
|
return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))
|
||
|
|
||
|
|
||
|
def auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
    pos_label: Optional[int] = None,
    average: Optional[str] = 'macro',
    max_fpr: Optional[float] = None,
    sample_weights: Optional[Sequence] = None,
) -> torch.Tensor:
    """ Compute `Area Under the Receiver Operating Characteristic Curve (ROC AUC)
    <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Further_interpretations>`_

    Args:
        preds: Predictions from model (probabilities)
        target: Ground truth labels
        num_classes: integer with number of classes. Not necessary to provide
            for binary problems.
        pos_label: integer determining the positive class. Default is ``None``
            which for binary problems is translated to 1. For multiclass problems
            this argument should not be set as we iteratively change it in the
            range [0,num_classes-1]
        average:
            - ``'macro'`` computes metric for each class and uniformly averages them
            - ``'weighted'`` computes metric for each class and does a weighted-average,
              where each class is weighted by their support (accounts for class imbalance)
            - ``None`` computes and returns the metric per class (as a list)
        max_fpr:
            If not ``None``, calculates standardized partial AUC over the
            range [0, max_fpr]. Should be a float in (0, 1]. Only supported for
            binary inputs and requires PyTorch 1.6 or newer.
        sample_weights: sample weights for each data point

    Returns:
        A scalar tensor with the (averaged) AUROC, or a list of per-class scores
        when ``average=None`` in the multiclass / multilabel setting.

    Raises:
        ValueError: if ``max_fpr`` is invalid or used outside the binary setting,
            or if ``average`` is not one of the supported values.
        RuntimeError: if ``max_fpr`` is set with PyTorch below 1.6.

    Example (binary case):

        >>> preds = torch.tensor([0.13, 0.26, 0.08, 0.19, 0.34])
        >>> target = torch.tensor([0, 0, 1, 1, 1])
        >>> auroc(preds, target, pos_label=1)
        tensor(0.5000)

    Example (multiclass case):

        >>> preds = torch.tensor([[0.90, 0.05, 0.05],
        ...                       [0.05, 0.90, 0.05],
        ...                       [0.05, 0.05, 0.90],
        ...                       [0.85, 0.05, 0.10],
        ...                       [0.10, 0.10, 0.80]])
        >>> target = torch.tensor([0, 1, 1, 2, 2])
        >>> auroc(preds, target, num_classes=3)
        tensor(0.7778)
    """
    # validate / reshape inputs, then compute the metric in a single pass
    preds, target, mode = _auroc_update(preds, target)
    return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights)
|