lightning/pytorch_lightning/metrics/utils.py

289 lines
9.3 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from pytorch_lightning.utilities import rank_zero_warn
# Small epsilon constant. It is not referenced in the visible part of this
# module; presumably importers use it as a numerical-stability guard (e.g. to
# avoid division by zero) -- TODO confirm against callers.
METRIC_EPS = 1e-6
def dim_zero_cat(x):
    """Concatenate ``x`` along dimension 0.

    Args:
        x: a list/tuple of tensors, or a single object which is wrapped in a
            one-element list before concatenation

    Return:
        The result of ``torch.cat`` over dimension 0
    """
    if not isinstance(x, (list, tuple)):
        # a lone input is treated as a one-element sequence
        x = [x]
    return torch.cat(x, dim=0)
def dim_zero_sum(x):
    """Sum ``x`` over dimension 0.

    Args:
        x: tensor to reduce

    Return:
        ``x`` summed along its first dimension
    """
    summed = torch.sum(x, dim=0)
    return summed
def dim_zero_mean(x):
    """Average ``x`` over dimension 0.

    Args:
        x: tensor to reduce

    Return:
        The mean of ``x`` along its first dimension
    """
    averaged = torch.mean(x, dim=0)
    return averaged
def _flatten(x):
    """Flatten one level of nesting.

    Args:
        x: an iterable of iterables

    Return:
        A new list holding the items of every inner iterable, in order
    """
    flat = []
    for inner in x:
        flat.extend(inner)
    return flat
def _check_same_shape(pred: torch.Tensor, target: torch.Tensor):
    """Verify that predictions and targets share one shape.

    Args:
        pred: predictions
        target: targets

    Raise:
        RuntimeError if ``pred.shape`` differs from ``target.shape``
    """
    if pred.shape == target.shape:
        return
    raise RuntimeError("Predictions and targets are expected to have the same shape")
def _input_format_classification_one_hot(
    num_classes: int,
    preds: torch.Tensor,
    target: torch.Tensor,
    threshold: float = 0.5,
    multilabel: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert preds and target tensors into one hot sparse label tensors

    Args:
        num_classes: number of classes
        preds: either tensor with labels, tensor with probabilities/logits or
            multilabel tensor
        target: tensor with ground truth labels
        threshold: float used for thresholding floating point ``preds``
            (entries ``>= threshold`` become 1)
        multilabel: boolean flag indicating if input is multilabel; when set,
            integer ``preds`` are not one-hot encoded

    Returns:
        preds: one hot tensor of shape [num_classes, -1] with predicted labels
        target: one hot tensors of shape [num_classes, -1] with true labels

    Raise:
        ValueError if ``preds`` has neither the same number of dimensions as
        ``target`` nor exactly one more
    """
    extra_dims = preds.ndim - target.ndim
    if extra_dims != 0 and extra_dims != 1:
        raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")

    if extra_dims == 1:
        # multi class probabilities: collapse the class dimension to labels
        preds = torch.argmax(preds, dim=1)

    same_rank = preds.ndim == target.ndim
    has_int_labels = preds.dtype in (torch.long, torch.int)

    if same_rank and has_int_labels and num_classes > 1 and not multilabel:
        # multi-class labels: expand both tensors to one-hot
        preds = to_onehot(preds, num_classes=num_classes)
        target = to_onehot(target, num_classes=num_classes)
    elif same_rank and preds.is_floating_point():
        # binary or multilabel probabilities: binarize with the threshold
        preds = (preds >= threshold).long()

    # move the class dimension to the front before flattening the rest
    if preds.ndim > 1:
        preds, target = preds.transpose(1, 0), target.transpose(1, 0)

    flat_preds = preds.reshape(num_classes, -1)
    flat_target = target.reshape(num_classes, -1)
    return flat_preds, flat_target
def to_onehot(
    label_tensor: torch.Tensor,
    num_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Turns a dense tensor of class indices into its one-hot encoding

    Args:
        label_tensor: dense label tensor, with shape [N, d1, d2, ...]
        num_classes: number of classes C; when ``None`` it is taken to be
            ``label_tensor.max() + 1``

    Output:
        A tensor with shape [N, C, d1, d2, ...], with the dtype and device of
        ``label_tensor``, holding 1 at each label's class index and 0 elsewhere

    Example:
        >>> x = torch.tensor([1, 2, 3])
        >>> to_onehot(x)
        tensor([[0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]])
    """
    if num_classes is None:
        num_classes = int(label_tensor.max().detach().item() + 1)

    # the class dimension is inserted right after the leading dimension
    onehot_shape = (label_tensor.shape[0], num_classes) + tuple(label_tensor.shape[1:])
    onehot = torch.zeros(
        *onehot_shape,
        dtype=label_tensor.dtype,
        device=label_tensor.device,
    )

    class_index = label_tensor.long().unsqueeze(1).expand_as(onehot)
    onehot.scatter_(1, class_index, 1.0)
    return onehot
def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
    """
    Binarize a probability tensor by marking its top-k highest entries.

    Args:
        prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
            position defined by the ``dim`` argument
        topk: number of highest entries to turn into 1s
        dim: dimension on which to compare entries

    Output:
        A binary tensor of the same shape as the input tensor of type torch.int32

    Example:
        >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
        >>> select_topk(x, topk=2)
        tensor([[0, 1, 1],
                [1, 1, 0]], dtype=torch.int32)
    """
    top_indices = prob_tensor.topk(k=topk, dim=dim).indices
    # write 1s at the top-k positions of an all-zero tensor shaped like the input
    binary = torch.zeros_like(prob_tensor).scatter(dim, top_indices, 1.0)
    return binary.int()
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
    """
    Turns a tensor of probabilities into a dense label tensor

    Args:
        tensor: probabilities to get the categorical label [N, d1, d2, ...]
        argmax_dim: dimension over which the argmax is taken

    Return:
        A tensor with categorical labels; the ``argmax_dim`` dimension is
        removed, giving [N, d2, ...] for the default ``argmax_dim=1``

    Example:
        >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
        >>> to_categorical(x)
        tensor([1, 0])
    """
    labels = tensor.argmax(dim=argmax_dim)
    return labels
def get_num_classes(
    pred: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
) -> int:
    """
    Determines the number of classes for a given prediction and target tensor.

    The count inferred from the data is ``max(pred.max(), target.max()) + 1``.

    Args:
        pred: predicted values
        target: true labels
        num_classes: number of classes if known

    Return:
        ``num_classes`` when it is given (a warning is emitted if it differs
        from the inferred count), otherwise the inferred count.
    """
    num_target_classes = int(target.max().detach().item() + 1)
    num_pred_classes = int(pred.max().detach().item() + 1)
    inferred = max(num_target_classes, num_pred_classes)

    if num_classes is None:
        return inferred

    if num_classes != inferred:
        # the user-supplied value wins, but the mismatch is reported
        rank_zero_warn(
            f"You have set {num_classes} number of classes which is"
            f" different from predicted ({num_pred_classes}) and"
            f" target ({num_target_classes}) number of classes",
            RuntimeWarning,
        )
    return num_classes
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Applies the requested reduction to a tensor

    Args:
        to_reduce : the tensor, which shall be reduced
        reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum')

    Return:
        the mean over all elements for ``'elementwise_mean'``, the sum over all
        elements for ``'sum'``, or the input tensor itself for ``'none'``

    Raise:
        ValueError if an invalid reduction parameter was given
    """
    if reduction == "none":
        return to_reduce
    elif reduction == "elementwise_mean":
        return torch.mean(to_reduce)
    elif reduction == "sum":
        return torch.sum(to_reduce)
    else:
        raise ValueError("Reduction parameter unknown.")
def class_reduce(
    num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none"
) -> torch.Tensor:
    """
    Reduces classification metrics of the form `num / denom * weights`.

    For example for calculating standard accuracy the num would be number of
    true positives per class, denom would be the support per class, and weights
    would be a tensor of 1s

    Args:
        num: numerator tensor
        denom: denominator tensor
        weights: weights for each class (only used by ``'weighted'``)
        class_reduction: reduction method for multiclass problems

            - ``'micro'``: calculate metrics globally
            - ``'macro'``: calculate metrics for each label, and find their unweighted mean.
            - ``'weighted'``: calculate metrics for each label, and find their weighted mean.
            - ``'none'`` or ``None``: returns calculated metric per class (default)

    Return:
        the reduced metric; NaN entries of the fraction are replaced by 0
        before the reduction is applied

    Raise:
        ValueError if ``class_reduction`` is not one of the values above
    """
    valid_reduction = ("micro", "macro", "weighted", "none", None)
    is_micro = class_reduction == "micro"

    # micro pools all classes before dividing; everything else divides per class
    fraction = torch.sum(num) / torch.sum(denom) if is_micro else num / denom

    # A zero denominator for some (or all) classes yields NaN; NaN is the only
    # value that is not equal to itself, so this mask selects exactly those.
    nan_mask = fraction != fraction
    fraction[nan_mask] = 0

    if is_micro or class_reduction == "none" or class_reduction is None:
        return fraction
    if class_reduction == "macro":
        return torch.mean(fraction)
    if class_reduction == "weighted":
        normalized_weights = weights.float() / torch.sum(weights)
        return torch.sum(fraction * normalized_weights)

    raise ValueError(
        f"Reduction parameter {class_reduction} unknown."
        f" Choose between one of these: {valid_reduction}"
    )
def _stable_1d_sort(x: torch.Tensor, N: int = 2049):
    """
    Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm
    if number of elements are larger than 2048. This function pads the tensors,
    makes the sort and returns the sorted array (with the padding removed)
    See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714

    Args:
        x: 1d tensor to sort
        N: number of elements the input is padded up to before sorting;
            inputs that already hold at least ``N`` elements (and empty
            inputs) are sorted without padding

    Return:
        A tuple ``(values, indices)`` with ``x.numel()`` entries each: the
        sorted values and the positions in ``x`` they came from

    Raise:
        ValueError if ``x`` has more than one dimension
    """
    if x.ndim > 1:
        raise ValueError('Stable sort only works on 1d tensors')

    n = x.numel()
    # Pad only when the input is shorter than N. Longer inputs are sorted
    # directly (previously this path hit an unassigned variable), and empty
    # inputs have no max() to build a pad value from.
    if 0 < n < N:
        # The pad value is strictly larger than every element, so all padding
        # sorts to the tail and is removed by the [:n] slices below.
        # NOTE(review): ``x.max() + 1`` could wrap around for integer dtypes
        # that already hold their maximum value -- confirm callers avoid this.
        pad_value = x.max() + 1
        padding = pad_value * torch.ones(N - n, dtype=x.dtype, device=x.device)
        x = torch.cat([x, padding], 0)

    x_sort = x.sort()
    return x_sort.values[:n], x_sort.indices[:n]