Native torch metrics (#1488)

* Create metric.py

* Create utils.py

* Create __init__.py

* Create __init__.py

* Create __init__.py

* add tests for metric utils

* add tests for metric utils

* add docstrings for metrics utils

* add docstrings for metrics utils

* add function to recursively apply other function to collection

* add function to recursively apply other function to collection

* add tests for this function

* add tests for this function

* add tests for this function

* update test

* update test

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* update metric name

* remove example docs

* fix tests

* fix tests

* add metric tests

* fix to tensor conversion

* fix to tensor conversion

* fix apply to collection

* fix apply to collection

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* remove tests from init

* remove tests from init

* add missing type annotations

* rename utils to convertors

* rename utils to convertors

* rename utils to convertors

* rename utils to convertors

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* add doctest example

* rename file and fix imports

* rename file and fix imports

* added parametrized test

* added parametrized test

* replace lambda with inlined function

* rename apply_to_collection to apply_func

* rename apply_to_collection to apply_func

* rename apply_to_collection to apply_func

* Separated class description from init args

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* adjust random values

* suppress output when seeding

* remove gpu from doctest

* Add requested changes and add ellipsis for doctest

* Add requested changes and add ellipsis for doctest

* Add requested changes and add ellipsis for doctest

* forgot to push these files...

* forgot to push these files...

* forgot to push these files...

* add explicit check for dtype to convert to

* add explicit check for dtype to convert to

* fix ddp tests

* fix ddp tests

* fix ddp tests

* remove explicit ddp destruction

* remove explicit ddp destruction

* New metric classes (#1326)

* Create metrics package

* Create metric.py

* Create utils.py

* Create __init__.py

* add tests for metric utils

* add docstrings for metrics utils

* add function to recursively apply other function to collection

* add tests for this function

* update test

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* update metric name

* remove example docs

* fix tests

* add metric tests

* fix to tensor conversion

* fix apply to collection

* Update CHANGELOG.md

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* remove tests from init

* add missing type annotations

* rename utils to convertors

* Create metrics.rst

* Update index.rst

* Update index.rst

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/metrics/metric.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/utilities/test_apply_to_collection.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/metrics/convertors.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* add doctest example

* rename file and fix imports

* added parametrized test

* replace lambda with inlined function

* rename apply_to_collection to apply_func

* Separated class description from init args

* Apply suggestions from code review

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* adjust random values

* suppress output when seeding

* remove gpu from doctest

* Add requested changes and add ellipsis for doctest

* forgot to push these files...

* add explicit check for dtype to convert to

* fix ddp tests

* remove explicit ddp destruction

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* add function to reduce tensors (similar to reduction in torch.nn)

* add functionals of reduction metrics

* add functionals of reduction metrics

* add more metrics

* pep8 fixes

* rename

* rename

* add reduction tests

* add first classification tests

* bugfixes

* bugfixes

* add more unit tests

* fix roc score metric

* fix tests

* solve tests

* fix docs

* Update CHANGELOG.md

* remove binaries

* solve changes from rebase

* add eos

* test auc independently

* fix formatting

* docs

* docs

* chlog

* move

* function descriptions

* Add documentation to native metrics (#2144)

* add docs

* add docs

* Apply suggestions from code review

* formatting

* add docs

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>

* Rename tests/metrics/test_classification.py to tests/metrics/functional/test_classification.py

* Rename tests/metrics/test_reduction.py to tests/metrics/functional/test_reduction.py

* Add module interface for classification metrics

* add basic tests for classification metrics' module interface

* pep8

* add additional converters

* add additional base class

* change baseclass for some metrics

* update classification tests

* update converter tests

* update metric tests

* Apply suggestions from code review

* tests-params

* tests-params

* imports

* pep8

* tests-params

* formatting

* fix test_metrics

* typo

* formatting

* fix dice tests

* fix decorator order

* fix tests

* seed

* dice test

* formatting

* try freeze test

* formatting

* fix tests

* try spawn

* formatting

* fix

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Xavier Sumba <c.uent@hotmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Author: Justus Schock, 2020-06-13 14:47:25 +02:00 (committed by GitHub)
Parent: 9df2b2090d
Commit: 3436d00230
15 changed files with 2249 additions and 221 deletions


@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)


@ -23,8 +23,8 @@ inputs to and outputs from numpy as well as automated ddp syncing.
"""
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric


@ -0,0 +1,652 @@
from typing import Any, Optional, Sequence, Tuple
import torch
from pytorch_lightning.metrics.functional.classification import (
accuracy,
confusion_matrix,
precision_recall_curve,
precision,
recall,
average_precision,
auroc,
fbeta_score,
f1_score,
roc,
multiclass_roc,
multiclass_precision_recall_curve,
dice_score
)
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric
__all__ = [
'Accuracy',
'ConfusionMatrix',
'PrecisionRecall',
'Precision',
'Recall',
'AveragePrecision',
'AUROC',
'FBeta',
'F1',
'ROC',
'MulticlassROC',
'MulticlassPrecisionRecall',
'DiceCoefficient'
]
class Accuracy(TensorMetric):
"""
Computes the accuracy classification score
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='accuracy',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the classification score.
"""
return accuracy(pred=pred, target=target,
num_classes=self.num_classes, reduction=self.reduction)
class ConfusionMatrix(TensorMetric):
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted in group j.
"""
def __init__(
self,
normalize: bool = False,
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
normalize: whether to compute a normalized confusion matrix
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='confusion_matrix',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.normalize = normalize
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the confusion matrix.
"""
return confusion_matrix(pred=pred, target=target,
normalize=self.normalize)
class PrecisionRecall(TensorCollectionMetric):
"""
Computes the precision recall curve
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='precision_recall_curve',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
torch.Tensor: precision values
torch.Tensor: recall values
torch.Tensor: threshold values
"""
return precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
class Precision(TensorMetric):
"""
Computes the precision score
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
num_classes: number of classes
reduction: a method for reducing precision values over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='precision',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the classification score.
"""
return precision(pred=pred, target=target,
num_classes=self.num_classes,
reduction=self.reduction)
class Recall(TensorMetric):
"""
Computes the recall score
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
num_classes: number of classes
reduction: a method for reducing recall values over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='recall',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the classification score.
"""
return recall(pred=pred,
target=target,
num_classes=self.num_classes,
reduction=self.reduction)
class AveragePrecision(TensorMetric):
"""
Computes the average precision score
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='AP',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None
) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
torch.Tensor: classification score
"""
return average_precision(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
class AUROC(TensorMetric):
"""
Computes the area under curve (AUC) of the receiver operator characteristic (ROC)
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='auroc',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None
) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
torch.Tensor: classification score
"""
return auroc(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
class FBeta(TensorMetric):
"""Computes the FBeta Score"""
def __init__(
self,
beta: float,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
beta: determines the weight of recall in the combined score.
num_classes: number of classes
reduction: a method for reducing F-scores over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='fbeta',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.beta = beta
self.num_classes = num_classes
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
Return:
torch.Tensor: classification score
"""
return fbeta_score(pred=pred, target=target,
beta=self.beta, num_classes=self.num_classes,
reduction=self.reduction)
class F1(TensorMetric):
"""Computes the F1 score"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
num_classes: number of classes
reduction: a method for reducing F1-scores over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='f1',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
Return:
torch.Tensor: classification score
"""
return f1_score(pred=pred, target=target,
num_classes=self.num_classes,
reduction=self.reduction)
class ROC(TensorCollectionMetric):
"""
Computes the Receiver Operator Characteristic (ROC)
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='roc',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
torch.Tensor: false positive rate
torch.Tensor: true positive rate
torch.Tensor: thresholds
"""
return roc(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
class MulticlassROC(TensorCollectionMetric):
"""
Computes the multiclass ROC
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
num_classes: number of classes
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='multiclass_roc',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes
def forward(
self, pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
Return:
tuple: A tuple consisting of one tuple per class,
holding false positive rate, true positive rate and thresholds
"""
return multiclass_roc(pred=pred,
target=target,
sample_weight=sample_weight,
num_classes=self.num_classes)
class MulticlassPrecisionRecall(TensorCollectionMetric):
"""Computes the multiclass PR Curve"""
def __init__(
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
num_classes: number of classes
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='multiclass_precision_recall_curve',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
Return:
tuple: A tuple consisting of one tuple per class,
holding precision, recall and thresholds
"""
return multiclass_precision_recall_curve(pred=pred,
target=target,
sample_weight=sample_weight,
num_classes=self.num_classes)
class DiceCoefficient(TensorMetric):
"""Computes the dice coefficient"""
def __init__(
self,
include_background: bool = False,
nan_score: float = 0.0, no_fg_score: float = 0.0,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
):
"""
Args:
include_background: whether to also compute dice for the background
nan_score: score to return, if a NaN occurs during computation (denom zero)
no_fg_score: score to return, if no foreground pixel was found in target
reduction: a method for reducing dice scores over classes (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='dice',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.include_background = include_background
self.nan_score = nan_score
self.no_fg_score = no_fg_score
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
Return:
torch.Tensor: the calculated dice coefficient
"""
return dice_score(pred=pred,
target=target,
bg=self.include_background,
nan_score=self.nan_score,
no_fg_score=self.no_fg_score,
reduction=self.reduction)
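For orientation, a minimal usage sketch of the module interface defined above. The import path pytorch_lightning.metrics.classification is an assumption based on the imports in this commit, the example tensors are made up, and a single-process (non-DDP) run is assumed:

import torch
from pytorch_lightning.metrics.classification import Accuracy  # assumed module path for this file

pred = torch.tensor([0, 1, 2, 2])    # predicted class labels for 4 samples
target = torch.tensor([0, 1, 1, 2])  # ground truth labels

metric = Accuracy(num_classes=3)
# __call__ wraps forward() with tensor conversion and optional DDP sync,
# so the result comes back as a tensor on the metric's device/dtype
score = metric(pred, target)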


@ -4,7 +4,6 @@ conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as u
sync tensors between different processes in a DDP scenario, when needed.
"""
import sys
import numbers
from typing import Union, Any, Callable, Optional
@ -18,12 +17,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
"""
Decorator function to apply a function to all inputs of a function.
Args:
func_to_apply: the function to apply to the inputs
*dec_args: positional arguments for the function to be applied
**dec_kwargs: keyword arguments for the function to be applied
Returns:
Return:
the decorated function
"""
@ -42,12 +42,13 @@ def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callab
def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
"""
Decorator function to apply a function to all outputs of a function.
Args:
func_to_apply: the function to apply to the outputs
*dec_args: positional arguments for the function to be applied
**dec_kwargs: keyword arguments for the function to be applied
Returns:
Return:
the decorated function
"""
@ -69,9 +70,8 @@ def _convert_to_tensor(data: Any) -> Any:
Args:
data: the data to convert to tensor
Returns:
Return:
the converted data
"""
if isinstance(data, numbers.Number):
return torch.tensor([data])
@ -86,12 +86,12 @@ def _convert_to_tensor(data: Any) -> Any:
def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
"""Convert all tensors and numpy arrays to numpy arrays.
Args:
data: the tensor or array to convert to numpy
Returns:
Return:
the resulting numpy array
"""
if isinstance(data, torch.Tensor):
return data.cpu().detach().numpy()
@ -103,6 +103,33 @@ def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) ->
raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all inputs of a function to numpy
Args:
func_to_decorate: the function whose inputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all outputs of a function to tensors
Args:
func_to_decorate: the function whose outputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_outputs(_convert_to_tensor)(func_to_decorate)
def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on numpy.
@ -112,19 +139,45 @@ def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
Args:
func_to_decorate: the function whose inputs and outputs shall be converted
Returns:
Return:
the decorated function
"""
# applies collection conversion from tensor to numpy to all inputs
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
func_convert_inputs = _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
func_convert_inputs = _numpy_metric_input_conversion(func_to_decorate)
# converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
func_convert_in_out = _tensor_metric_output_conversion(func_convert_inputs)
return func_convert_in_out
def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all inputs of a function to tensors
Args:
func_to_decorate: the function whose inputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all numpy arrays and numbers occurring in the outputs of a function to tensors
Args:
func_to_decorate: the function whose outputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
_convert_to_tensor)(func_to_decorate)
def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on tensors.
@ -133,16 +186,33 @@ def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
Args:
func_to_decorate: the function whose inputs and outputs shall be converted
Returns:
Return:
the decorated function
"""
# converts all inputs to tensor if possible
# we need to include tensors here, since otherwise they will also be treated as sequences
func_convert_inputs = _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
# convert all outputs to tensor if possible
return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
return _tensor_metric_output_conversion(func_convert_inputs)
def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on tensors.
All inputs of the decorated function and all numpy arrays and numbers in
its outputs will be converted to tensors.
Args:
func_to_decorate: the function whose inputs and outputs shall be converted
Return:
the decorated function
"""
# converts all inputs to tensor if possible
# we need to include tensors here, since otherwise they will also be treated as sequences
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
# convert all outputs to tensor if possible
return _tensor_collection_metric_output_conversion(func_convert_inputs)
def _sync_ddp_if_available(result: Union[torch.Tensor],
@ -157,9 +227,8 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum.
Returns:
Return:
reduced value
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
@ -177,11 +246,32 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
return result
def sync_ddp(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
"""
This decorator syncs a function's outputs across different processes for DDP.
Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor,
_sync_ddp_if_available, group=group,
reduce_op=reduce_op)(func_to_decorate)
return decorator_fn
def numpy_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on numpy arrays.
It handles the argument conversion and DDP reduction for metrics working on numpy.
All inputs of the decorated function will be converted to numpy and all
outputs will be converted to tensors.
@ -191,15 +281,12 @@ def numpy_metric(group: Optional[Any] = None,
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Returns:
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
group=group,
reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
return sync_ddp(group=group, reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
return decorator_fn
@ -208,7 +295,6 @@ def tensor_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors.
It handles the argument conversion and DDP reduction for metrics working on tensors.
All inputs and outputs of the decorated function will be converted to tensors.
In DDP Training all output tensors will be reduced according to the given rules.
@ -217,14 +303,34 @@ def tensor_metric(group: Optional[Any] = None,
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Returns:
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
group=group,
reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
return decorator_fn
def tensor_collection_metric(group: Optional[Any] = None,
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors and returning collections
that cannot be converted to tensors.
It handles the argument conversion and DDP reduction for metrics working on tensors.
All inputs and outputs of the decorated function will be converted to tensors.
In DDP Training all output tensors will be reduced according to the given rules.
Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_collection_metric_conversion(func_to_decorate))
return decorator_fn
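A short sketch of how these decorators are meant to be composed onto plain metric functions. The my_rmse function below is illustrative only and not part of this commit; it assumes a single-process run, in which the DDP sync is a no-op:

import numpy as np
import torch
from pytorch_lightning.metrics.converters import tensor_metric

@tensor_metric()  # converts inputs/outputs to tensors and all-reduces outputs when DDP is initialized
def my_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return torch.sqrt(torch.mean((pred - target) ** 2))

# numpy inputs are converted to tensors before the call; the result stays a tensor
score = my_rmse(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 5.0]))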


@ -0,0 +1,693 @@
from collections import Sequence
from functools import wraps
from typing import Optional, Tuple, Callable
import torch
from pytorch_lightning.metrics.functional.reduction import reduce
def to_onehot(
tensor: torch.Tensor,
n_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Converts a dense label tensor to one-hot format
Args:
tensor: dense label tensor, with shape [N, d1, d2, ...]
n_classes: number of classes C
Output:
A sparse label tensor with shape [N, C, d1, d2, ...]
"""
if n_classes is None:
n_classes = int(tensor.max().detach().item() + 1)
dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
tensor_onehot = torch.zeros(shape[0], n_classes, *shape[1:],
dtype=dtype, device=device)
index = tensor.long().unsqueeze(1).expand_as(tensor_onehot)
return tensor_onehot.scatter_(1, index, 1.0)
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
"""
Converts a tensor of probabilities to a dense label tensor
Args:
tensor: probabilities to get the categorical label [N, d1, d2, ...]
argmax_dim: dimension to apply (default: 1)
Return:
A tensor with categorical labels [N, d2, ...]
"""
return torch.argmax(tensor, dim=argmax_dim)
def get_num_classes(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int],
) -> int:
"""
Returns the number of classes for a given prediction and target tensor.
Args:
pred: predicted values
target: true labels
num_classes: number of classes if known (default: None)
Return:
An integer that represents the number of classes.
"""
if num_classes is None:
if pred.ndim > target.ndim:
num_classes = pred.size(1)
else:
num_classes = int(target.max().detach().item() + 1)
return num_classes
def stat_scores(
pred: torch.Tensor,
target: torch.Tensor,
class_index: int, argmax_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calculates the number of true positive, false positive, true negative
and false negative for a specific class
Args:
pred: prediction tensor
target: target tensor
class_index: class to calculate over
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
Return:
Tensors in the following order: True Positive, False Positive, True Negative, False Negative
"""
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)
tp = ((pred == class_index) * (target == class_index)).to(torch.long).sum()
fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
return tp, fp, tn, fn
def stat_scores_multiple_classes(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
argmax_dim: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Calls the stat_scores function iteratively for all classes, thus
calculating the number of true positive, false positive, true negative
and false negative for each class
Args:
pred: prediction tensor
target: target tensor
num_classes: number of classes if known (default: None)
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
Return:
Returns tensors for: tp, fp, tn, fn
"""
num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)
tps = torch.zeros((num_classes,), device=pred.device)
fps = torch.zeros((num_classes,), device=pred.device)
tns = torch.zeros((num_classes,), device=pred.device)
fns = torch.zeros((num_classes,), device=pred.device)
for c in range(num_classes):
tps[c], fps[c], tns[c], fns[c] = stat_scores(pred=pred, target=target,
class_index=c)
return tps, fps, tns, fns
def accuracy(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
reduction='elementwise_mean',
) -> torch.Tensor:
"""
Computes the accuracy classification score
Args:
pred: predicted labels
target: ground truth labels
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
Return:
A Tensor with the classification score.
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
num_classes=num_classes)
if not (target > 0).any() and num_classes is None:
raise RuntimeError("cannot infer num_classes when target is all zero")
accuracies = (tps + tns) / (tps + tns + fps + fns)
return reduce(accuracies, reduction=reduction)
def confusion_matrix(
pred: torch.Tensor,
target: torch.Tensor,
normalize: bool = False,
) -> torch.Tensor:
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted in group j.
Args:
pred: estimated targets
target: ground truth labels
normalize: normalizes confusion matrix
Return:
Tensor, confusion matrix C [num_classes, num_classes ]
"""
num_classes = get_num_classes(pred, target, None)
d = target.size(-1)
batch_vec = torch.arange(target.size(-1))
# this will account for multilabel
unique_labels = batch_vec * num_classes ** 2 + target.view(-1) * num_classes + pred.view(-1)
bins = torch.bincount(unique_labels, minlength=d * num_classes ** 2)
cm = bins.reshape(d, num_classes, num_classes).squeeze().float()
if normalize:
cm = cm / cm.sum(-1)
return cm
def precision_recall(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes precision and recall for different thresholds
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision-recall values (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
Return:
Tensor with precision and recall
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
target=target,
num_classes=num_classes)
tps = tps.to(torch.float)
fps = fps.to(torch.float)
fns = fns.to(torch.float)
precision = tps / (tps + fps)
recall = tps / (tps + fns)
precision = reduce(precision, reduction=reduction)
recall = reduce(recall, reduction=reduction)
return precision, recall
def precision(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
"""
Computes precision score.
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision values (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
Return:
Tensor with precision.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[0]
def recall(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
"""
Computes recall score.
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing recall values (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements
Return:
Tensor with recall.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[1]
def fbeta_score(
pred: torch.Tensor,
target: torch.Tensor,
beta: float,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
"""
Computes the F-beta score which is a weighted harmonic mean of precision and recall.
It ranges between 0 and 1, where 1 is the best value and 0 is the worst.
Args:
pred: estimated probabilities
target: ground-truth labels
beta: weights recall when combining the score.
beta < 1: more weight to precision.
beta > 1: more weight to recall
beta = 0: only precision
beta -> inf: only recall
num_classes: number of classes
reduction: method for reducing F-score (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements.
Return:
Tensor with the value of F-score. It is a value between 0-1.
"""
prec, rec = precision_recall(pred=pred, target=target,
num_classes=num_classes,
reduction='none')
nom = (1 + beta ** 2) * prec * rec
denom = ((beta ** 2) * prec + rec)
fbeta = nom / denom
return reduce(fbeta, reduction=reduction)
def f1_score(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
reduction='elementwise_mean',
) -> torch.Tensor:
"""
Computes F1-score a.k.a F-measure.
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing F1-score (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: pass array
- sum: add elements.
Return:
Tensor containing F1-score
"""
return fbeta_score(pred=pred, target=target, beta=1.,
num_classes=num_classes, reduction=reduction)
def _binary_clf_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
"""
if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)
# remove class dimension if necessary
if pred.ndim > target.ndim:
pred = pred[:, 0]
desc_score_indices = torch.argsort(pred, descending=True)
pred = pred[desc_score_indices]
target = target[desc_score_indices]
if sample_weight is not None:
weight = sample_weight[desc_score_indices]
else:
weight = 1.
# pred typically has many tied values. Here we extract
# the indices associated with the distinct values. We also
# concatenate a value for the end of the curve.
distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
threshold_idxs = torch.cat([distinct_value_indices,
torch.tensor([target.size(0) - 1])])
target = (target == pos_label).to(torch.long)
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
if sample_weight is not None:
# express fps as a cumsum to ensure fps is increasing even in
# the presence of floating point errors
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
else:
fps = 1 + threshold_idxs - tps
return fps, tps, pred[threshold_idxs]
def roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes the classifier is binary.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1)
Return:
[Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")
fpr = fps / fps[-1]
if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")
tpr = tps / tps[-1]
return fpr, tpr, thresholds
def multiclass_roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
num_classes: number of classes (default: None, computes automatically from data)
Return:
[num_classes, Tensor, Tensor, Tensor]: returns roc for each class.
number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
"""
num_classes = get_num_classes(pred, target, num_classes)
class_roc_vals = []
for c in range(num_classes):
pred_c = pred[:, c]
class_roc_vals.append(roc(pred=pred_c, target=target,
sample_weight=sample_weight, pos_label=c))
return tuple(class_roc_vals)
def precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes precision-recall pairs for different thresholds.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)
Return:
[Tensor, Tensor, Tensor]: precision, recall, thresholds
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
precision = tps / (tps + fps)
recall = tps / tps[-1]
# stop when full recall attained
# and reverse the outputs so recall is decreasing
last_ind = torch.where(tps == tps[-1])[0][0]
sl = slice(0, last_ind.item() + 1)
# need to call reversed explicitly, since including that to slice would
# introduce negative strides that are not yet supported in pytorch
precision = torch.cat([reversed(precision[sl]),
torch.ones(1, dtype=precision.dtype,
device=precision.device)])
recall = torch.cat([reversed(recall[sl]),
torch.zeros(1, dtype=recall.dtype,
device=recall.device)])
thresholds = torch.tensor(reversed(thresholds[sl]))
return precision, recall, thresholds
def multiclass_precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Computes precision-recall pairs for different thresholds given multiclass scores.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weight
num_classes: number of classes
Return:
[num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds
"""
num_classes = get_num_classes(pred, target, num_classes)
class_pr_vals = []
for c in range(num_classes):
pred_c = pred[:, c]
class_pr_vals.append(precision_recall_curve(
pred=pred_c,
target=target,
sample_weight=sample_weight, pos_label=c))
return tuple(class_pr_vals)
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
"""
Computes Area Under the Curve (AUC) using the trapezoidal rule
Args:
x: x-coordinates
y: y-coordinates
reorder: reorder coordinates, so they are increasing.
Return:
AUC score (float)
"""
direction = 1.
if reorder:
# can't use lexsort here since it is not implemented for torch
order = torch.argsort(x)
x, y = x[order], y[order]
else:
dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx <= 0).all():
direction = -1.
else:
raise ValueError("Reordering is not turned on, and "
"the x array is not increasing: %s" % x)
return direction * torch.trapz(y, x)
def auc_decorator(reorder: bool = True) -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
@wraps(func_to_decorate)
def new_func(*args, **kwargs) -> torch.Tensor:
x, y = func_to_decorate(*args, **kwargs)[:2]
return auc(x, y, reorder=reorder)
return new_func
return wrapper
def multiclass_auc_decorator(reorder: bool = True) -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
def new_func(*args, **kwargs) -> torch.Tensor:
results = []
for class_result in func_to_decorate(*args, **kwargs):
x, y = class_result[:2]
results.append(auc(x, y, reorder=reorder))
return torch.stack(results)
return new_func
return wrapper
def auroc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> torch.Tensor:
"""
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)
"""
@auc_decorator(reorder=True)
def _auroc(pred, target, sample_weight, pos_label):
return roc(pred, target, sample_weight, pos_label)
return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)
def average_precision(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> torch.Tensor:
precision, recall, _ = precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
# Return the step function integral
# The following works because the last entry of precision is
# guaranteed to be 1, as returned by precision_recall_curve
return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
def dice_score(
pred: torch.Tensor,
target: torch.Tensor,
bg: bool = False,
nan_score: float = 0.0,
no_fg_score: float = 0.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
n_classes = pred.shape[1]
bg = (1 - int(bool(bg)))
scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)
for i in range(bg, n_classes):
if not (target == i).any():
# no foreground class
scores[i - bg] += no_fg_score
continue
tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)
denom = (2 * tp + fp + fn).to(torch.float)
if torch.isclose(denom, torch.zeros_like(denom)).any():
# nan result
score_cls = nan_score
else:
score_cls = (2 * tp).to(torch.float) / denom
scores[i - bg] += score_cls
return reduce(scores, reduction=reduction)
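To make the functional interface above concrete, a small self-contained example (values chosen by hand; not part of the commit):

import torch
from pytorch_lightning.metrics.functional.classification import (
    accuracy, to_onehot, to_categorical, roc)

target = torch.tensor([0, 1, 1, 0])
pred_labels = torch.tensor([0, 1, 0, 0])

acc = accuracy(pred_labels, target, num_classes=2)   # mean of the per-class accuracies
onehot = to_onehot(target, n_classes=2)              # shape [4, 2]
labels_back = to_categorical(onehot, argmax_dim=1)   # recovers the dense labels

# ROC for a binary problem from probability scores of the positive class
scores = torch.tensor([0.1, 0.9, 0.4, 0.2])
fpr, tpr, thresholds = roc(scores, target, pos_label=1)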


@ -0,0 +1,24 @@
import torch
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
"""
Reduces a given tensor by a given reduction method
Args:
to_reduce : the tensor, which shall be reduced
reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
Return:
reduced Tensor
Raise:
ValueError if an invalid reduction parameter was given
"""
if reduction == 'elementwise_mean':
return torch.mean(to_reduce)
if reduction == 'none':
return to_reduce
if reduction == 'sum':
return torch.sum(to_reduce)
raise ValueError('Reduction parameter unknown.')
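A quick illustration of the reduction helper above (the input values are arbitrary):

import torch
from pytorch_lightning.metrics.functional.reduction import reduce

scores = torch.tensor([0.5, 1.0, 0.0])
reduce(scores, reduction='elementwise_mean')  # tensor(0.5000)
reduce(scores, reduction='sum')               # tensor(1.5000)
reduce(scores, reduction='none')              # returns the input unchanged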


@ -3,16 +3,16 @@ from typing import Any, Optional
import torch
import torch.distributed
from torch.nn import Module
from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
from pytorch_lightning.metrics.converters import (
tensor_metric, numpy_metric, tensor_collection_metric)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']
class Metric(ABC, DeviceDtypeModuleMixin, Module):
class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
"""
Abstract base class for metric implementation.
@ -20,6 +20,7 @@ class Metric(ABC, DeviceDtypeModuleMixin, Module):
1. Return multiple Outputs
2. Handle their own DDP sync
"""
def __init__(self, name: str):
"""
Args:
@ -49,6 +50,7 @@ class TensorMetric(Metric):
All inputs and outputs will be casted to tensors if necessary.
Already handles DDP sync and input/output conversions.
"""
def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
@ -73,6 +75,47 @@ class TensorMetric(Metric):
_to_device_dtype)
class TensorCollectionMetric(Metric):
"""
Base class for metric implementation operating directly on tensors.
All inputs will be casted to tensors if necessary. Outputs won't be casted.
Already handles DDP sync and input conversions.
This class differs from :class:`TensorMetric`, as it assumes all outputs to
be collections of tensors and does not explicitly convert them. This is
necessary, since some collections (like for ROC, Precision-Recall Curve etc.)
cannot be converted to tensors at the highest level.
All numpy arrays and numbers occurring in these outputs will still be converted.
Use this class as a base class whenever you want to ensure that inputs are
tensors, but the outputs cannot be converted to tensors automatically.
"""
def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
"""
Args:
name: the metric's name
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
"""
super().__init__(name)
self._orig_call = tensor_collection_metric(group=reduce_group,
reduce_op=reduce_op)(super().__call__)
def __call__(self, *args, **kwargs) -> torch.Tensor:
def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
return x.to(device=self.device, dtype=self.dtype, non_blocking=True)
return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
_to_device_dtype)
class NumpyMetric(Metric):
"""
Base class for metric implementation operating on numpy arrays.
@ -80,6 +123,7 @@ class NumpyMetric(Metric):
be casted to tensors if necessary.
Already handles DDP sync and input/output conversions.
"""
def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
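To round off the class-based API, a minimal sketch of defining a custom metric on top of TensorMetric. The MeanAbsoluteError class below is illustrative and not part of this commit; a single-process run is assumed:

import torch
from pytorch_lightning.metrics.metric import TensorMetric

class MeanAbsoluteError(TensorMetric):
    """Illustrative metric: mean absolute difference between pred and target."""

    def __init__(self, reduce_group=None, reduce_op=None):
        super().__init__(name='mae', reduce_group=reduce_group, reduce_op=reduce_op)

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torch.mean(torch.abs(pred - target))

metric = MeanAbsoluteError()
metric(torch.tensor([1., 2., 3.]), torch.tensor([1., 2., 5.]))  # tensor(0.6667)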


@ -1,130 +0,0 @@
import numbers
from typing import Union, Any, Optional
import numpy as np
import torch
from torch.utils.data._utils.collate import default_convert
from pytorch_lightning.utilities.apply_func import apply_to_collection
def _apply_to_inputs(func_to_apply, *dec_args, **dec_kwargs):
def decorator_fn(func_to_decorate):
def new_func(*args, **kwargs):
args = func_to_apply(args, *dec_args, **dec_kwargs)
kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
return func_to_decorate(*args, **kwargs)
return new_func
return decorator_fn
def _apply_to_outputs(func_to_apply, *dec_args, **dec_kwargs):
def decorator_fn(function_to_decorate):
def new_func(*args, **kwargs):
result = function_to_decorate(*args, **kwargs)
return func_to_apply(result, *dec_args, **dec_kwargs)
return new_func
return decorator_fn
def _convert_to_tensor(data: Any) -> Any:
"""
Maps all kind of collections and numbers to tensors
Args:
data: the data to convert to tensor
Returns:
the converted data
"""
if isinstance(data, numbers.Number):
return torch.tensor([data])
else:
return default_convert(data)
def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
"""
Converts tensors, numpy arrays and numbers to numpy arrays
Args:
data: the tensor or array to convert to numpy
Returns:
the resulting numpy array
"""
if isinstance(data, torch.Tensor):
return data.cpu().detach().numpy()
elif isinstance(data, numbers.Number):
return np.array([data])
return data
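For orientation, what the two converters above produce (illustrative only, single process):
assert torch.equal(_convert_to_tensor(3.0), torch.tensor([3.0]))          # numbers become 1-element tensors
assert isinstance(_convert_to_numpy(torch.tensor([1., 2.])), np.ndarray)  # tensors become numpy arrays
assert _convert_to_numpy(5).tolist() == [5]                               # numbers become 1-element arrays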
def _numpy_metric_conversion(func_to_decorate):
# Applies collection conversion from tensor to numpy to all inputs
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
func_convert_inputs = _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
# converts all outputs back to tensors (device doesn't matter here, since this is handled by the Metric base class)
func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
return func_convert_in_out
def _tensor_metric_conversion(func_to_decorate):
# Converts all inputs to tensor if possible
func_convert_inputs = _apply_to_inputs(_convert_to_tensor)(func_to_decorate)
# convert all outputs to tensor if possible
return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
def _sync_ddp(result: Union[torch.Tensor],
group: Any = torch.distributed.group.WORLD,
reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM,
) -> torch.Tensor:
"""
Function to sync and reduce a tensor across several DDP processes
Args:
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Returns:
reduced value
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group,
async_op=False)
return result
def numpy_metric(group: Any = torch.distributed.group.WORLD,
reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
group=group,
reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
return decorator_fn
def tensor_metric(group: Any = torch.distributed.group.WORLD,
reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
group=group,
reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
return decorator_fn
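To make the intended usage concrete, a sketch of decorating a plain function with numpy_metric (function name made up; in a single process the DDP sync is a no-op):
@numpy_metric()
def rmse(pred, target):
    # pred and target arrive here as numpy arrays
    return np.sqrt(np.mean((pred - target) ** 2))

result = rmse(torch.tensor([1., 2.]), torch.tensor([1., 4.]))
assert isinstance(result, torch.Tensor)  # the scalar result was converted back to a tensor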

View File

@@ -0,0 +1,309 @@
import pytest
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional.classification import (
to_onehot,
to_categorical,
get_num_classes,
stat_scores,
stat_scores_multiple_classes,
accuracy,
confusion_matrix,
precision,
recall,
fbeta_score,
f1_score,
_binary_clf_curve,
dice_score,
average_precision,
auroc,
precision_recall_curve,
roc,
auc,
)
def test_onehot():
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
expected = torch.tensor([
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
], [
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
]
])
assert test_tensor.shape == (2, 5)
assert expected.shape == (2, 10, 5)
onehot_classes = to_onehot(test_tensor, n_classes=10)
onehot_no_classes = to_onehot(test_tensor)
assert torch.allclose(onehot_classes, onehot_no_classes)
assert onehot_classes.shape == expected.shape
assert onehot_no_classes.shape == expected.shape
assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
assert torch.allclose(expected.to(onehot_classes), onehot_classes)
def test_to_categorical():
test_tensor = torch.tensor([
[
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0]
], [
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
]
]).to(torch.float)
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
assert expected.shape == (2, 5)
assert test_tensor.shape == (2, 10, 5)
result = to_categorical(test_tensor)
assert result.shape == expected.shape
assert torch.allclose(result, expected.to(result.dtype))
@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
])
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
assert get_num_classes(pred, target, num_classes) == expected_num_classes
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1)
])
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
tp, fp, tn, fn = stat_scores(pred, target, class_index=4)
assert tp.item() == expected_tp
assert fp.item() == expected_fp
assert tn.item() == expected_tn
assert fn.item() == expected_fn
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1])
])
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
tp, fp, tn, fn = stat_scores_multiple_classes(pred, target)
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
def test_multilabel_accuracy():
# Dense label indicator matrix format
y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([0.8333333134651184] * 2))
assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
assert torch.allclose(accuracy(y1, torch.logical_not(y1), reduction='none'), torch.tensor([0., 0.]))
with pytest.raises(RuntimeError):
accuracy(y2, torch.zeros_like(y2), reduction='none')
def test_confusion_matrix():
target = (torch.arange(120) % 3).view(-1, 1)
pred = target.clone()
cm = confusion_matrix(pred, target, normalize=True)
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))
pred = torch.zeros_like(pred)
cm = confusion_matrix(pred, target, normalize=True)
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))
@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
])
def test_precision_recall(pred, target, expected_prec, expected_rec):
prec = precision(pred, target, reduction='none')
rec = recall(pred, target, reduction='none')
assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)
@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
])
def test_fbeta_score(pred, target, beta, exp_score):
score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
])
def test_f1_score(pred, target, exp_score):
score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [
pytest.param(1, 1., 42),
pytest.param(None, 1., 42),
])
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
# TODO: move pred and target back to the test function arguments
# if the arrays inside the function are changed, the expected shape has to be adjusted as well
seed_everything(0)
pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
target = torch.tensor([0, 1] * 50, dtype=torch.int)
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight
fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
assert isinstance(tps, torch.Tensor)
assert isinstance(fps, torch.Tensor)
assert isinstance(thresh, torch.Tensor)
assert tps.shape == (exp_shape,)
assert fps.shape == (exp_shape,)
assert thresh.shape == (exp_shape,)
@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
])
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
assert p.size() == r.size()
assert p.size(0) == t.size(0) + 1
assert torch.allclose(p, torch.tensor(expected_p).to(p))
assert torch.allclose(r, torch.tensor(expected_r).to(r))
assert torch.allclose(t, torch.tensor(expected_t).to(t))
@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
])
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
assert fpr.shape == tpr.shape
assert fpr.size(0) == thresh.size(0)
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0, 0, 1, 1], [0, 0, 1, 1], 1.),
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
])
def test_auroc(pred, target, expected):
score = auroc(torch.tensor(pred), torch.tensor(target)).item()
assert score == expected
@pytest.mark.parametrize(['x', 'y', 'expected'], [
pytest.param([0, 1], [0, 1], 0.5),
pytest.param([1, 0], [0, 1], 0.5),
pytest.param([1, 0, 0], [0, 1, 1], 0.5),
pytest.param([0, 1], [1, 1], 1),
pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
])
def test_auc(x, y, expected):
# Test Area Under Curve (AUC) computation
assert auc(torch.tensor(x), torch.tensor(y)) == expected
def test_average_precision_constant_values():
# Check the average_precision_score of a constant predictor is
# the TPR
# Generate a dataset with 25% of positives
target = torch.zeros(100, dtype=torch.float)
target[::4] = 1
# And a constant score
pred = torch.ones(100)
# The precision is then the fraction of positives whatever the recall
# is, as there is only one threshold:
assert average_precision(pred, target).item() == .25
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),
pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
])
def test_dice_score(pred, target, expected):
score = dice_score(torch.tensor(pred), torch.tensor(target))
assert score == expected
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py

View File

@@ -0,0 +1,15 @@
import pytest
import torch
from pytorch_lightning.metrics.functional.reduction import reduce
def test_reduce():
start_tensor = torch.rand(50, 40, 30)
assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)
with pytest.raises(ValueError):
reduce(start_tensor, 'error_reduction')

View File

@@ -0,0 +1,227 @@
# NOTE: This file only tests that the metric modules run fine with their arguments.
# The actual metric implementations are tested in functional/test_classification.py
# In particular, reduction and reducing across processes are not tested here!
import pytest
import torch
from pytorch_lightning.metrics.classification import (
Accuracy,
ConfusionMatrix,
PrecisionRecall,
Precision,
Recall,
AveragePrecision,
AUROC,
FBeta,
F1,
ROC,
MulticlassROC,
MulticlassPrecisionRecall,
DiceCoefficient,
)
@pytest.fixture
def random():
torch.manual_seed(0)
@pytest.mark.parametrize('num_classes', [1, None])
def test_accuracy(num_classes):
acc = Accuracy(num_classes=num_classes)
assert acc.name == 'accuracy'
result = acc(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
assert isinstance(result, torch.Tensor)
@pytest.mark.parametrize('normalize', [False, True])
def test_confusion_matrix(normalize):
conf_matrix = ConfusionMatrix(normalize=normalize)
assert conf_matrix.name == 'confusion_matrix'
target = (torch.arange(120) % 3).view(-1, 1)
pred = target.clone()
cm = conf_matrix(pred, target)
assert isinstance(cm, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2.])
def test_precision_recall(pos_label):
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
pr_curve = PrecisionRecall(pos_label=pos_label)
assert pr_curve.name == 'precision_recall_curve'
pr = pr_curve(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(pr, tuple)
assert len(pr) == 3
for tmp in pr:
assert isinstance(tmp, torch.Tensor)
@pytest.mark.parametrize('num_classes', [1, None])
def test_precision(num_classes):
precision = Precision(num_classes=num_classes)
assert precision.name == 'precision'
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
prec = precision(pred=pred, target=target)
assert isinstance(prec, torch.Tensor)
@pytest.mark.parametrize('num_classes', [1, None])
def test_recall(num_classes):
recall = Recall(num_classes=num_classes)
assert recall.name == 'recall'
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
rec = recall(pred=pred, target=target)
assert isinstance(rec, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2])
def test_average_precision(pos_label):
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
avg_prec = AveragePrecision(pos_label=pos_label)
assert avg_prec.name == 'AP'
ap = avg_prec(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(ap, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2])
def test_auroc(pos_label):
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
auroc = AUROC(pos_label=pos_label)
assert auroc.name == 'auroc'
area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(area, torch.Tensor)
@pytest.mark.parametrize(['beta', 'num_classes'], [
pytest.param(0., 1),
pytest.param(0.5, 1),
pytest.param(1., 1),
pytest.param(2., 1),
pytest.param(0., None),
pytest.param(0.5, None),
pytest.param(1., None),
pytest.param(2., None)
])
def test_fbeta(beta, num_classes):
fbeta = FBeta(beta=beta, num_classes=num_classes)
assert fbeta.name == 'fbeta'
score = fbeta(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
assert isinstance(score, torch.Tensor)
@pytest.mark.parametrize('num_classes', [1, None])
def test_f1(num_classes):
f1 = F1(num_classes=num_classes)
assert f1.name == 'f1'
score = f1(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
assert isinstance(score, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2])
def test_roc(pos_label):
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3])
roc = ROC(pos_label=pos_label)
assert roc.name == 'roc'
res = roc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(res, tuple)
assert len(res) == 3
for tmp in res:
assert isinstance(tmp, torch.Tensor)
@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_roc(num_classes):
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
multi_roc = MulticlassROC(num_classes=num_classes)
assert multi_roc.name == 'multiclass_roc'
res = multi_roc(pred, target)
assert isinstance(res, tuple)
if num_classes is not None:
assert len(res) == num_classes
for tmp in res:
assert isinstance(tmp, tuple)
assert len(tmp) == 3
for _tmp in tmp:
assert isinstance(_tmp, torch.Tensor)
@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_pr(num_classes):
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
multi_pr = MulticlassPrecisionRecall(num_classes=num_classes)
assert multi_pr.name == 'multiclass_precision_recall_curve'
pr = multi_pr(pred, target)
assert isinstance(pr, tuple)
if num_classes is not None:
assert len(pr) == num_classes
for tmp in pr:
assert isinstance(tmp, tuple)
assert len(tmp) == 3
for _tmp in tmp:
assert isinstance(_tmp, torch.Tensor)
@pytest.mark.parametrize('include_background', [True, False])
def test_dice_coefficient(include_background):
dice_coeff = DiceCoefficient(include_background=include_background)
assert dice_coeff.name == 'dice'
dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)),
torch.randint(0, 1, (10, 25, 25)))
assert isinstance(dice, torch.Tensor)

View File

@@ -6,16 +6,19 @@ import torch.multiprocessing as mp
import tests.base.utils as tutils
from pytorch_lightning.metrics.converters import (
_apply_to_inputs, _apply_to_outputs, _convert_to_tensor, _convert_to_numpy,
_numpy_metric_conversion, _tensor_metric_conversion, _sync_ddp_if_available, tensor_metric, numpy_metric)
_apply_to_inputs,
_apply_to_outputs,
_convert_to_tensor,
_convert_to_numpy,
_numpy_metric_conversion,
_tensor_metric_conversion,
_sync_ddp_if_available,
tensor_metric,
numpy_metric
)
@pytest.mark.parametrize(['args', 'kwargs'],
[pytest.param([], {}),
pytest.param([1., 2.], {}),
pytest.param([], {'a': 1., 'b': 2.}),
pytest.param([1., 2.], {'a': 1., 'b': 2.})])
def test_apply_to_inputs(args, kwargs):
def test_apply_to_inputs():
def apply_fn(inputs, factor):
if isinstance(inputs, (float, int)):
return inputs * factor
@@ -25,22 +28,24 @@ def test_apply_to_inputs(args, kwargs):
return [apply_fn(x, factor) for x in inputs]
@_apply_to_inputs(apply_fn, factor=2.)
def test_fn(*func_args, **func_kwargs):
return func_args, func_kwargs
def test_fn(*args, **kwargs):
return args, kwargs
result_args, result_kwargs = test_fn(*args, **kwargs)
assert isinstance(result_args, (list, tuple))
assert isinstance(result_kwargs, dict)
assert len(result_args) == len(args)
assert len(result_kwargs) == len(kwargs)
assert all([k in result_kwargs for k in kwargs.keys()])
for arg, result_arg in zip(args, result_args):
assert arg * 2. == result_arg
for args in [[], [1., 2.]]:
for kwargs in [{}, {'a': 1., 'b': 2.}]:
result_args, result_kwargs = test_fn(*args, **kwargs)
assert isinstance(result_args, (list, tuple))
assert isinstance(result_kwargs, dict)
assert len(result_args) == len(args)
assert len(result_kwargs) == len(kwargs)
assert all([k in result_kwargs for k in kwargs.keys()])
for arg, result_arg in zip(args, result_args):
assert arg * 2. == result_arg
for key in kwargs.keys():
arg = kwargs[key]
result_arg = result_kwargs[key]
assert arg * 2. == result_arg
for key in kwargs.keys():
arg = kwargs[key]
result_arg = result_kwargs[key]
assert arg * 2. == result_arg
def test_apply_to_outputs():
@@ -100,7 +105,7 @@ def test_tensor_metric_conversion():
assert result.item() == 5.
def setup_ddp(rank, worldsize, ):
def _setup_ddp(rank, worldsize):
import os
os.environ['MASTER_ADDR'] = 'localhost'
@@ -109,8 +114,8 @@ def setup_ddp(rank, worldsize, ):
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
def ddp_test_fn(rank, worldsize):
setup_ddp(rank, worldsize)
def _ddp_test_fn(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.], device='cuda:0')
reduced_tensor = _sync_ddp_if_available(tensor)
@@ -119,6 +124,7 @@ def ddp_test_fn(rank, worldsize):
'Sync-Reduce does not work properly with DDP and Tensors'
@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
@@ -126,7 +132,9 @@ def test_sync_reduce_ddp():
tutils.set_random_master_port()
worldsize = 2
mp.spawn(ddp_test_fn, args=(worldsize,), nprocs=worldsize)
mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
# dist.destroy_process_group()
def test_sync_reduce_simple():
@@ -161,16 +169,18 @@ def _test_tensor_metric(is_ddp: bool):
def _ddp_test_tensor_metric(rank, worldsize):
setup_ddp(rank, worldsize)
_setup_ddp(rank, worldsize)
_test_tensor_metric(True)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_tensor_metric_ddp():
tutils.reset_seed()
tutils.set_random_master_port()
world_size = 2
mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)
# dist.destroy_process_group()
def test_tensor_metric_simple():
@@ -199,16 +209,19 @@ def _test_numpy_metric(is_ddp: bool):
def _ddp_test_numpy_metric(rank, worldsize):
setup_ddp(rank, worldsize)
_setup_ddp(rank, worldsize)
_test_numpy_metric(True)
@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_numpy_metric_ddp():
tutils.reset_seed()
tutils.set_random_master_port()
world_size = 2
mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
# dist.destroy_process_group()
def test_numpy_metric_simple():
_test_tensor_metric(False)
_test_numpy_metric(False)

View File

@@ -1,7 +1,7 @@
import numpy as np
import torch
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric
class DummyTensorMetric(TensorMetric):
@@ -24,7 +24,65 @@ class DummyNumpyMetric(NumpyMetric):
return 1.
class DummyTensorCollectionMetric(TensorCollectionMetric):
def __init__(self):
super().__init__('dummy')
def forward(self, input1, input2):
assert isinstance(input1, torch.Tensor)
assert isinstance(input2, torch.Tensor)
return 1., 2., 3., 4.
def _test_collection_metric(metric: Metric):
""" Test that metric.device, metric.dtype works for metric collection """
input1, input2 = torch.tensor([1.]), torch.tensor([2.])
def change_and_check_device_dtype(device, dtype):
metric.to(device=device, dtype=dtype)
metric_val = metric(input1, input2)
assert not isinstance(metric_val, torch.Tensor)
if device is not None:
assert metric.device in [device, torch.device(device)]
if dtype is not None:
assert metric.dtype == dtype
devices = [None, 'cpu']
if torch.cuda.is_available():
devices += ['cuda:0']
for device in devices:
for dtype in [None, torch.float32, torch.float64]:
change_and_check_device_dtype(device=device, dtype=dtype)
if torch.cuda.is_available():
metric.cuda(0)
assert metric.device == torch.device('cuda', index=0)
metric.cpu()
assert metric.device == torch.device('cpu')
metric.type(torch.int8)
assert metric.dtype == torch.int8
metric.float()
assert metric.dtype == torch.float32
metric.double()
assert metric.dtype == torch.float64
assert all(out.dtype == torch.float64 for out in metric(input1, input2))
if torch.cuda.is_available():
metric.cuda()
metric.half()
assert metric.dtype == torch.float16
def _test_metric(metric: Metric):
""" Test that metric.device, metric.dtype works for single metric"""
input1, input2 = torch.tensor([1.]), torch.tensor([2.])
def change_and_check_device_dtype(device, dtype):
@@ -83,3 +141,7 @@ def test_tensor_metric():
def test_numpy_metric():
_test_metric(DummyNumpyMetric())
def test_tensor_collection():
_test_collection_metric(DummyTensorCollectionMetric())

View File

@@ -5,13 +5,24 @@ from functools import partial
import numpy as np
import pytest
import torch
from sklearn.metrics import (accuracy_score, average_precision_score, auc, confusion_matrix, f1_score,
fbeta_score, precision_score, recall_score, precision_recall_curve, roc_curve,
roc_auc_score)
from sklearn.metrics import (
accuracy_score,
average_precision_score,
auc,
confusion_matrix,
f1_score,
fbeta_score,
precision_score,
recall_score,
precision_recall_curve,
roc_curve,
roc_auc_score
)
from pytorch_lightning.metrics.converters import _convert_to_numpy
from pytorch_lightning.metrics.sklearn import (Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
from pytorch_lightning.metrics.sklearn import (
Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -25,37 +36,38 @@ def xy_only(func):
@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
pytest.param(Accuracy(), accuracy_score,
{'y_pred': torch.randint(low=0, high=10, size=(128,)),
'y_true': torch.randint(low=0, high=10, size=(128,))}, id='Accuracy'),
'y_true': torch.randint(low=0, high=10, size=(128,))},
id='Accuracy'),
pytest.param(AUC(), auc, {'x': torch.arange(10, dtype=torch.float) / 10,
'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2,
0.2, 0.3, 0.5, 0.6, 0.7])}, id='AUC'),
'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])},
id='AUC'),
pytest.param(AveragePrecision(), average_precision_score,
{'y_score': torch.randint(2, size=(128,)),
'y_true': torch.randint(2, size=(128,))}, id='AveragePrecision'),
{'y_score': torch.randint(2, size=(128,)), 'y_true': torch.randint(2, size=(128,))},
id='AveragePrecision'),
pytest.param(ConfusionMatrix(), confusion_matrix,
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))}, id='ConfusionMatrix'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='ConfusionMatrix'),
pytest.param(F1(average='macro'), partial(f1_score, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))}, id='F1'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='F1'),
pytest.param(FBeta(beta=0.5, average='macro'), partial(fbeta_score, beta=0.5, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))}, id='FBeta'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='FBeta'),
pytest.param(Precision(average='macro'), partial(precision_score, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))}, id='Precision'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='Precision'),
pytest.param(Recall(average='macro'), partial(recall_score, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))}, id='Recall'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='Recall'),
pytest.param(PrecisionRecallCurve(), xy_only(precision_recall_curve),
{'probas_pred': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))}, id='PrecisionRecallCurve'),
{'probas_pred': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
id='PrecisionRecallCurve'),
pytest.param(ROC(), xy_only(roc_curve),
{'y_score': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))}, id='ROC'),
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
id='ROC'),
pytest.param(AUROC(), roc_auc_score,
{'y_score': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))}, id='AUROC'),
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
id='AUROC'),
])
def test_sklearn_metric(metric_class, sklearn_func, inputs: dict):
numpy_inputs = apply_to_collection(