2020-10-13 11:18:07 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2021-03-15 11:18:43 +00:00
|
|
|
from typing import List, Optional, Tuple
|
2020-10-06 21:03:24 +00:00
|
|
|
|
2020-11-23 08:44:35 +00:00
|
|
|
import torch
|
2020-10-06 21:03:24 +00:00
|
|
|
|
2020-12-04 21:42:23 +00:00
|
|
|
from pytorch_lightning.utilities import rank_zero_warn
|
2020-10-06 21:03:24 +00:00
|
|
|
|
2020-10-10 16:31:00 +00:00
|
|
|
# Small epsilon constant (1e-6) exported by this module.
# NOTE(review): not referenced by the code shown here; presumably used by metric
# implementations elsewhere to guard against division by zero — confirm at call sites.
METRIC_EPS = 1e-6
|
|
|
|
|
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
def dim_zero_cat(x):
    """Concatenate ``x`` along dimension 0.

    ``x`` may be a single tensor or a list/tuple of tensors; a lone tensor is
    wrapped in a list so that ``torch.cat`` always receives a sequence.
    """
    pieces = [x] if not isinstance(x, (list, tuple)) else x
    return torch.cat(pieces, dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
def dim_zero_sum(x):
    """Sum the tensor ``x`` over its first dimension."""
    return torch.sum(x, 0)
|
|
|
|
|
|
|
|
|
|
|
|
def dim_zero_mean(x):
    """Average the tensor ``x`` over its first dimension."""
    return torch.mean(x, 0)
|
|
|
|
|
|
|
|
|
|
|
|
def _flatten(x):
|
|
|
|
return [item for sublist in x for item in sublist]
|
2020-10-10 16:31:00 +00:00
|
|
|
|
|
|
|
|
2020-10-21 22:05:59 +00:00
|
|
|
def _check_same_shape(pred: torch.Tensor, target: torch.Tensor):
|
|
|
|
""" Check that predictions and target have the same shape, else raise error """
|
|
|
|
if pred.shape != target.shape:
|
2020-12-07 16:49:35 +00:00
|
|
|
raise RuntimeError("Predictions and targets are expected to have the same shape")
|
2020-10-30 10:44:25 +00:00
|
|
|
|
|
|
|
|
2020-11-23 08:44:35 +00:00
|
|
|
def _input_format_classification_one_hot(
|
2021-02-01 08:24:07 +00:00
|
|
|
num_classes: int,
|
|
|
|
preds: torch.Tensor,
|
|
|
|
target: torch.Tensor,
|
|
|
|
threshold: float = 0.5,
|
|
|
|
multilabel: bool = False
|
2020-11-23 08:44:35 +00:00
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
2020-12-07 16:49:35 +00:00
|
|
|
"""Convert preds and target tensors into one hot spare label tensors
|
2020-11-23 08:44:35 +00:00
|
|
|
|
|
|
|
Args:
|
|
|
|
num_classes: number of classes
|
|
|
|
preds: either tensor with labels, tensor with probabilities/logits or
|
|
|
|
multilabel tensor
|
|
|
|
target: tensor with ground true labels
|
|
|
|
threshold: float used for thresholding multilabel input
|
|
|
|
multilabel: boolean flag indicating if input is multilabel
|
|
|
|
|
2021-03-04 19:34:03 +00:00
|
|
|
Raises:
|
|
|
|
ValueError:
|
|
|
|
If ``preds`` and ``target`` don't have the same number of dimensions
|
|
|
|
or one additional dimension for ``preds``.
|
|
|
|
|
2020-11-23 08:44:35 +00:00
|
|
|
Returns:
|
|
|
|
preds: one hot tensor of shape [num_classes, -1] with predicted labels
|
|
|
|
target: one hot tensors of shape [num_classes, -1] with true labels
|
|
|
|
"""
|
2020-12-07 16:49:35 +00:00
|
|
|
if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1):
|
|
|
|
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
|
2020-11-23 08:44:35 +00:00
|
|
|
|
2020-12-07 16:49:35 +00:00
|
|
|
if preds.ndim == target.ndim + 1:
|
2020-11-23 08:44:35 +00:00
|
|
|
# multi class probabilites
|
|
|
|
preds = torch.argmax(preds, dim=1)
|
|
|
|
|
2020-12-07 16:49:35 +00:00
|
|
|
if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel:
|
2020-11-23 08:44:35 +00:00
|
|
|
# multi-class
|
|
|
|
preds = to_onehot(preds, num_classes=num_classes)
|
|
|
|
target = to_onehot(target, num_classes=num_classes)
|
|
|
|
|
2020-12-07 16:49:35 +00:00
|
|
|
elif preds.ndim == target.ndim and preds.is_floating_point():
|
2020-11-23 08:44:35 +00:00
|
|
|
# binary or multilabel probablities
|
|
|
|
preds = (preds >= threshold).long()
|
|
|
|
|
|
|
|
# transpose class as first dim and reshape
|
2020-12-07 16:49:35 +00:00
|
|
|
if preds.ndim > 1:
|
2020-11-23 08:44:35 +00:00
|
|
|
preds = preds.transpose(1, 0)
|
|
|
|
target = target.transpose(1, 0)
|
|
|
|
|
|
|
|
return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)
|
2020-12-04 21:42:23 +00:00
|
|
|
|
|
|
|
|
2021-03-15 11:18:43 +00:00
|
|
|
def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]:
    """
    Group the positions of ``idx`` by value.

    Given an integer `torch.Tensor` `idx`, return a list containing one
    `torch.Tensor` of positions for each distinct value in `idx`. Groups are
    ordered by the first appearance of their value.

    Args:
        idx: a `torch.Tensor` of integers

    Return:
        A list of integer (int64) `torch.Tensor`s

    Example:

        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
        >>> groups = get_group_indexes(indexes)
        >>> groups
        [tensor([0, 1, 2]), tensor([3, 4, 5, 6])]
    """
    groups = {}
    for position, value in enumerate(idx):
        # dicts keep insertion order, so groups follow first appearance
        groups.setdefault(value.item(), []).append(position)
    return [torch.tensor(members, dtype=torch.int64) for members in groups.values()]
|
|
|
|
|
|
|
|
|
2020-12-04 21:42:23 +00:00
|
|
|
def to_onehot(
    label_tensor: torch.Tensor,
    num_classes: Optional[int] = None,
) -> torch.Tensor:
    """
    Convert a dense label tensor to one-hot format.

    Args:
        label_tensor: dense label tensor, with shape [N, d1, d2, ...]
        num_classes: number of classes C; when ``None`` it is inferred as
            ``label_tensor.max() + 1``

    Returns:
        A sparse label tensor with shape [N, C, d1, d2, ...] and the same
        dtype/device as ``label_tensor``

    Example:

        >>> from pytorch_lightning.metrics.utils import to_onehot
        >>> x = torch.tensor([1, 2, 3])
        >>> to_onehot(x)
        tensor([[0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]])
    """
    if num_classes is None:
        num_classes = int(label_tensor.max().detach().item() + 1)

    # insert the class dimension right after the batch dimension
    out_shape = (label_tensor.shape[0], num_classes, *label_tensor.shape[1:])
    onehot = torch.zeros(out_shape, dtype=label_tensor.dtype, device=label_tensor.device)

    scatter_index = label_tensor.long().unsqueeze(1).expand_as(onehot)
    return onehot.scatter_(1, scatter_index, 1.0)
|
|
|
|
|
|
|
|
|
2020-12-07 16:49:35 +00:00
|
|
|
def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
    """
    Binarise a probability tensor by marking its top-k entries along ``dim``.

    Args:
        prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the
            position defined by the ``dim`` argument
        topk: number of highest entries to turn into 1s
        dim: dimension on which to compare entries

    Returns:
        A binary tensor of the same shape as the input tensor of type torch.int32

    Example:

        >>> from pytorch_lightning.metrics.utils import select_topk
        >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
        >>> select_topk(x, topk=2)
        tensor([[0, 1, 1],
                [1, 1, 0]], dtype=torch.int32)
    """
    top_indices = prob_tensor.topk(k=topk, dim=dim).indices
    mask = torch.zeros_like(prob_tensor).scatter(dim, top_indices, 1.0)
    return mask.int()
|
|
|
|
|
|
|
|
|
|
|
|
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
    """
    Collapse a tensor of probabilities into dense labels via ``argmax``.

    Args:
        tensor: probabilities to get the categorical label [N, d1, d2, ...]
        argmax_dim: dimension along which the argmax is taken

    Return:
        A tensor with categorical labels [N, d2, ...] (for the default ``argmax_dim=1``)

    Example:

        >>> from pytorch_lightning.metrics.utils import to_categorical
        >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
        >>> to_categorical(x)
        tensor([1, 0])
    """
    labels = tensor.argmax(dim=argmax_dim)
    return labels
|
|
|
|
|
|
|
|
|
|
|
|
def get_num_classes(
    pred: torch.Tensor,
    target: torch.Tensor,
    num_classes: Optional[int] = None,
) -> int:
    """
    Determine the number of classes for a prediction/target pair.

    The count is inferred as ``max(pred.max(), target.max()) + 1``. If
    ``num_classes`` is given it is returned as-is, with a warning when it
    disagrees with the inferred count.

    Args:
        pred: predicted values
        target: true labels
        num_classes: number of classes if known

    Return:
        An integer that represents the number of classes.
    """
    num_target_classes = int(target.max().detach().item() + 1)
    num_pred_classes = int(pred.max().detach().item() + 1)
    inferred = max(num_target_classes, num_pred_classes)

    if num_classes is None:
        return inferred

    if num_classes != inferred:
        rank_zero_warn(
            f"You have set {num_classes} number of classes which is"
            f" different from predicted ({num_pred_classes}) and"
            f" target ({num_target_classes}) number of classes",
            RuntimeWarning,
        )
    return num_classes
|
|
|
|
|
|
|
|
|
|
|
|
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
    """
    Apply a named reduction to a tensor.

    Args:
        to_reduce: the tensor to reduce
        reduction: one of ``'elementwise_mean'`` (mean over all elements),
            ``'none'`` (return the input unchanged) or ``'sum'``

    Return:
        reduced Tensor

    Raise:
        ValueError if an invalid reduction parameter was given
    """
    if reduction == "elementwise_mean":
        reduced = torch.mean(to_reduce)
    elif reduction == "none":
        reduced = to_reduce
    elif reduction == "sum":
        reduced = torch.sum(to_reduce)
    else:
        raise ValueError("Reduction parameter unknown.")
    return reduced
|
2020-12-04 21:42:23 +00:00
|
|
|
|
|
|
|
|
2020-12-07 16:49:35 +00:00
|
|
|
def class_reduce(
    num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none"
) -> torch.Tensor:
    """
    Function used to reduce classification metrics of the form `num / denom * weights`.
    For example for calculating standard accuracy the num would be number of
    true positives per class, denom would be the support per class, and weights
    would be a tensor of 1s

    Args:
        num: numerator tensor
        denom: denominator tensor
        weights: weights for each class
        class_reduction: reduction method for multiclass problems

            - ``'micro'``: calculate metrics globally
            - ``'macro'``: calculate metrics for each label, and find their unweighted mean.
            - ``'weighted'``: calculate metrics for each label, and find their weighted mean.
            - ``'none'`` or ``None``: returns calculated metric per class (default)

        Classes whose ratio is ``0 / 0`` (NaN) contribute 0 to the result.

    Raises:
        ValueError:
            If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``.
    """
    valid_reduction = ("micro", "macro", "weighted", "none", None)
    if class_reduction == "micro":
        fraction = torch.sum(num) / torch.sum(denom)
    else:
        fraction = num / denom

    # We need to take care of instances where the denom can be 0
    # for some (or all) classes which will produce nans
    fraction[torch.isnan(fraction)] = 0

    if class_reduction == "micro":
        return fraction
    elif class_reduction == "macro":
        return torch.mean(fraction)
    elif class_reduction == "weighted":
        return torch.sum(fraction * (weights.float() / torch.sum(weights)))
    elif class_reduction == "none" or class_reduction is None:
        return fraction

    raise ValueError(
        f"Reduction parameter {class_reduction} unknown."
        f" Choose between one of these: {valid_reduction}"
    )
|
2021-01-27 13:16:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
def _stable_1d_sort(x: torch, N: int = 2049):
|
|
|
|
"""
|
|
|
|
Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm
|
|
|
|
if number of elements are larger than 2048. This function pads the tensors,
|
|
|
|
makes the sort and returns the sorted array (with the padding removed)
|
|
|
|
See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714
|
2021-03-04 19:34:03 +00:00
|
|
|
|
|
|
|
Raises:
|
|
|
|
ValueError:
|
|
|
|
If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors.
|
2021-01-27 13:16:54 +00:00
|
|
|
"""
|
|
|
|
if x.ndim > 1:
|
|
|
|
raise ValueError('Stable sort only works on 1d tensors')
|
|
|
|
n = x.numel()
|
|
|
|
if N - n > 0:
|
|
|
|
x_max = x.max()
|
2021-03-04 19:18:57 +00:00
|
|
|
x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0)
|
|
|
|
x_sort = x.sort()
|
|
|
|
i = min(N, n)
|
|
|
|
return x_sort.values[:i], x_sort.indices[:i]
|