Native torch metrics (#1488)
* Create metric.py * Create utils.py * Create __init__.py * Create __init__.py * Create __init__.py * add tests for metric utils * add tests for metric utils * add docstrings for metrics utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add function to recursively apply other function to collection * add tests for this function * add tests for this function * add tests for this function * update test * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * update metric name * remove example docs * fix tests * fix tests * add metric tests * fix to tensor conversion * fix to tensor conversion * fix apply to collection * fix apply to collection * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * remove tests from init * remove tests from init * add missing type annotations * rename utils to convertors * rename utils to convertors * rename utils to convertors * rename utils to convertors * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * add doctest example * rename file and fix imports * rename file and fix imports * added parametrized test * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * rename apply_to_collection to apply_func * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * Add requested changes and add ellipsis for doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * forgot to push these files... * forgot to push these files... * add explicit check for dtype to convert to * add explicit check for dtype to convert to * fix ddp tests * fix ddp tests * fix ddp tests * remove explicit ddp destruction * remove explicit ddp destruction * New metric classes (#1326) * Create metrics package * Create metric.py * Create utils.py * Create __init__.py * add tests for metric utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add tests for this function * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * update metric name * remove example docs * fix tests * add metric tests * fix to tensor conversion * fix apply to collection * Update CHANGELOG.md * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * remove tests from init * add missing type annotations * rename utils to convertors * Create metrics.rst * Update index.rst * Update index.rst * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/utilities/test_apply_to_collection.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/metrics/convertors.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * add doctest example * rename file and fix imports * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * add explicit check for dtype to convert to * fix ddp tests * remove explicit ddp destruction Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * add function to reduce tensors (similar to reduction in torch.nn) * add functionals of reduction metrics * add functionals of reduction metrics * add more metrics * pep8 fixes * rename * rename * add reduction tests * add first classification tests * bugfixes * bugfixes * add more unit tests * fix roc score metric * fix tests * solve tests * fix docs * Update CHANGELOG.md * remove binaries * solve changes from rebase * add eos * test auc independently * fix formatting * docs * docs * chlog * move * function descriptions * Add documentation to native metrics (#2144) * add docs * add docs * Apply suggestions from code review * formatting * add docs Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka <jirka@pytorchlightning.ai> * Rename tests/metrics/test_classification.py to tests/metrics/functional/test_classification.py * Rename tests/metrics/test_reduction.py to tests/metrics/functional/test_reduction.py * Add module interface for classification metrics * add basic tests for classification metrics' module interface * pep8 * add additional converters * add additional base class * change baseclass for some metrics * update classification tests * update converter tests * update metric tests * Apply suggestions from code review * tests-params * tests-params * imports * pep8 * tests-params * formatting * fix test_metrics * typo * formatting * fix dice tests * fix decorator order * fix tests * seed * dice test * formatting * try freeze test * formatting * fix tests * try spawn * formatting * fix Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: Xavier Sumba <c.uent@hotmail.com> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Nicki Skafte <nugginea@gmail.com>
This commit is contained in:
parent
9df2b2090d
commit
3436d00230
|
@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
|
||||
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
|
||||
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
|
||||
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
|
||||
- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723))
|
||||
- Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
|
||||
- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
|
||||
|
|
|
@ -23,8 +23,8 @@ inputs to and outputs from numpy as well as automated ddp syncing.
|
|||
|
||||
"""
|
||||
|
||||
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
|
||||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
|
||||
from pytorch_lightning.metrics.sklearn import (
|
||||
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
|
||||
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
|
||||
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
|
||||
|
|
|
@ -0,0 +1,652 @@
|
|||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import (
|
||||
accuracy,
|
||||
confusion_matrix,
|
||||
precision_recall_curve,
|
||||
precision,
|
||||
recall,
|
||||
average_precision,
|
||||
auroc,
|
||||
fbeta_score,
|
||||
f1_score,
|
||||
roc,
|
||||
multiclass_roc,
|
||||
multiclass_precision_recall_curve,
|
||||
dice_score
|
||||
)
|
||||
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric
|
||||
|
||||
__all__ = [
|
||||
'Accuracy',
|
||||
'ConfusionMatrix',
|
||||
'PrecisionRecall',
|
||||
'Precision',
|
||||
'Recall',
|
||||
'AveragePrecision',
|
||||
'AUROC',
|
||||
'FBeta',
|
||||
'F1',
|
||||
'ROC',
|
||||
'MulticlassROC',
|
||||
'MulticlassPrecisionRecall',
|
||||
'DiceCoefficient'
|
||||
]
|
||||
|
||||
|
||||
class Accuracy(TensorMetric):
|
||||
"""
|
||||
Computes the accuracy classification score
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='accuracy',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
self.num_classes = num_classes
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the classification score.
|
||||
"""
|
||||
return accuracy(pred=pred, target=target,
|
||||
num_classes=self.num_classes, reduction=self.reduction)
|
||||
|
||||
|
||||
class ConfusionMatrix(TensorMetric):
|
||||
"""
|
||||
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
|
||||
in group i that were predicted in group j.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
normalize: bool = False,
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
normalize: whether to compute a normalized confusion matrix
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
"""
|
||||
super().__init__(name='confusion_matrix',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
self.normalize = normalize
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the confusion matrix.
|
||||
"""
|
||||
return confusion_matrix(pred=pred, target=target,
|
||||
normalize=self.normalize)
|
||||
|
||||
|
||||
class PrecisionRecall(TensorCollectionMetric):
|
||||
"""
|
||||
Computes the precision recall curve
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='precision_recall_curve',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
torch.Tensor: precision values
|
||||
torch.Tensor: recall values
|
||||
torch.Tensor: threshold values
|
||||
"""
|
||||
return precision_recall_curve(pred=pred, target=target,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=self.pos_label)
|
||||
|
||||
|
||||
class Precision(TensorMetric):
|
||||
"""
|
||||
Computes the precision score
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='precision',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the classification score.
|
||||
"""
|
||||
return precision(pred=pred, target=target,
|
||||
num_classes=self.num_classes,
|
||||
reduction=self.reduction)
|
||||
|
||||
|
||||
class Recall(TensorMetric):
|
||||
"""
|
||||
Computes the recall score
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='recall',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the classification score.
|
||||
"""
|
||||
return recall(pred=pred,
|
||||
target=target,
|
||||
num_classes=self.num_classes,
|
||||
reduction=self.reduction)
|
||||
|
||||
|
||||
class AveragePrecision(TensorMetric):
|
||||
"""
|
||||
Computes the average precision score
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='AP',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
torch.Tensor: classification score
|
||||
"""
|
||||
return average_precision(pred=pred, target=target,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=self.pos_label)
|
||||
|
||||
|
||||
class AUROC(TensorMetric):
|
||||
"""
|
||||
Computes the area under curve (AUC) of the receiver operator characteristic (ROC)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='auroc',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
torch.Tensor: classification score
|
||||
"""
|
||||
return auroc(pred=pred, target=target,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=self.pos_label)
|
||||
|
||||
|
||||
class FBeta(TensorMetric):
|
||||
"""Computes the FBeta Score"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
beta: determines the weight of recall in the combined score.
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='fbeta',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.beta = beta
|
||||
self.num_classes = num_classes
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
|
||||
Return:
|
||||
torch.Tensor: classification score
|
||||
"""
|
||||
return fbeta_score(pred=pred, target=target,
|
||||
beta=self.beta, num_classes=self.num_classes,
|
||||
reduction=self.reduction)
|
||||
|
||||
|
||||
class F1(TensorMetric):
|
||||
"""Computes the F1 score"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='f1',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
|
||||
Return:
|
||||
torch.Tensor: classification score
|
||||
"""
|
||||
return f1_score(pred=pred, target=target,
|
||||
num_classes=self.num_classes,
|
||||
reduction=self.reduction)
|
||||
|
||||
|
||||
class ROC(TensorCollectionMetric):
|
||||
"""
|
||||
Computes the Receiver Operator Characteristic (ROC)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='roc',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
torch.Tensor: false positive rate
|
||||
torch.Tensor: true positive rate
|
||||
torch.Tensor: thresholds
|
||||
"""
|
||||
return roc(pred=pred, target=target,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=self.pos_label)
|
||||
|
||||
|
||||
class MulticlassROC(TensorCollectionMetric):
|
||||
"""
|
||||
Computes the multiclass ROC
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='multiclass_roc',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(
|
||||
self, pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: Weights for each sample defining the sample's impact on the score
|
||||
|
||||
Return:
|
||||
tuple: A tuple consisting of one tuple per class,
|
||||
holding false positive rate, true positive rate and thresholds
|
||||
"""
|
||||
return multiclass_roc(pred=pred,
|
||||
target=target,
|
||||
sample_weight=sample_weight,
|
||||
num_classes=self.num_classes)
|
||||
|
||||
|
||||
class MulticlassPrecisionRecall(TensorCollectionMetric):
|
||||
"""Computes the multiclass PR Curve"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
|
||||
"""
|
||||
super().__init__(name='multiclass_precision_recall_curve',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: Weights for each sample defining the sample's impact on the score
|
||||
|
||||
Return:
|
||||
tuple: A tuple consisting of one tuple per class,
|
||||
holding precision, recall and thresholds
|
||||
"""
|
||||
return multiclass_precision_recall_curve(pred=pred,
|
||||
target=target,
|
||||
sample_weight=sample_weight,
|
||||
num_classes=self.num_classes)
|
||||
|
||||
|
||||
class DiceCoefficient(TensorMetric):
|
||||
"""Computes the dice coefficient"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
include_background: bool = False,
|
||||
nan_score: float = 0.0, no_fg_score: float = 0.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
reduce_group: Any = None,
|
||||
reduce_op: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
include_background: whether to also compute dice for the background
|
||||
nan_score: score to return, if a NaN occurs during computation (denom zero)
|
||||
no_fg_score: score to return, if no foreground pixel was found in target
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
reduce_op: the operation to perform for ddp reduction
|
||||
"""
|
||||
super().__init__(name='dice',
|
||||
reduce_group=reduce_group,
|
||||
reduce_op=reduce_op)
|
||||
|
||||
self.include_background = include_background
|
||||
self.nan_score = nan_score
|
||||
self.no_fg_score = no_fg_score
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
|
||||
Return:
|
||||
torch.Tensor: the calculated dice coefficient
|
||||
"""
|
||||
return dice_score(pred=pred,
|
||||
target=target,
|
||||
bg=self.include_background,
|
||||
nan_score=self.nan_score,
|
||||
no_fg_score=self.no_fg_score,
|
||||
reduction=self.reduction)
|
|
@ -4,7 +4,6 @@ conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as u
|
|||
sync tensors between different processes in a DDP scenario, when needed.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import numbers
|
||||
from typing import Union, Any, Callable, Optional
|
||||
|
||||
|
@ -18,12 +17,13 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
|
|||
def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
|
||||
"""
|
||||
Decorator function to apply a function to all inputs of a function.
|
||||
|
||||
Args:
|
||||
func_to_apply: the function to apply to the inputs
|
||||
*dec_args: positional arguments for the function to be applied
|
||||
**dec_kwargs: keyword arguments for the function to be applied
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
|
@ -42,12 +42,13 @@ def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callab
|
|||
def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
|
||||
"""
|
||||
Decorator function to apply a function to all outputs of a function.
|
||||
|
||||
Args:
|
||||
func_to_apply: the function to apply to the outputs
|
||||
*dec_args: positional arguments for the function to be applied
|
||||
**dec_kwargs: keyword arguments for the function to be applied
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
|
@ -69,9 +70,8 @@ def _convert_to_tensor(data: Any) -> Any:
|
|||
Args:
|
||||
data: the data to convert to tensor
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the converted data
|
||||
|
||||
"""
|
||||
if isinstance(data, numbers.Number):
|
||||
return torch.tensor([data])
|
||||
|
@ -86,12 +86,12 @@ def _convert_to_tensor(data: Any) -> Any:
|
|||
|
||||
def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
|
||||
"""Convert all tensors and numpy arrays to numpy arrays.
|
||||
|
||||
Args:
|
||||
data: the tensor or array to convert to numpy
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the resulting numpy array
|
||||
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.cpu().detach().numpy()
|
||||
|
@ -103,6 +103,33 @@ def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) ->
|
|||
raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
|
||||
|
||||
|
||||
def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all inputs of a function to numpy
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_inputs(
|
||||
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
|
||||
|
||||
|
||||
def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all outputs of a function to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose outputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_outputs(_convert_to_tensor)(func_to_decorate)
|
||||
|
||||
|
||||
def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator handling the argument conversion for metrics working on numpy.
|
||||
|
@ -112,19 +139,45 @@ def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
|
|||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
# applies collection conversion from tensor to numpy to all inputs
|
||||
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _apply_to_inputs(
|
||||
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
|
||||
func_convert_inputs = _numpy_metric_input_conversion(func_to_decorate)
|
||||
# converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
|
||||
func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
|
||||
func_convert_in_out = _tensor_metric_output_conversion(func_convert_inputs)
|
||||
return func_convert_in_out
|
||||
|
||||
|
||||
def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all inputs of a function to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_inputs(
|
||||
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
|
||||
|
||||
|
||||
def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all numpy arrays and numbers occuring in the outputs of a function to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose outputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
|
||||
_convert_to_tensor)(func_to_decorate)
|
||||
|
||||
|
||||
def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator Handling the argument conversion for metrics working on tensors.
|
||||
|
@ -133,16 +186,33 @@ def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
|
|||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
# converts all inputs to tensor if possible
|
||||
# we need to include tensors here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _apply_to_inputs(
|
||||
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_tensor)(func_to_decorate)
|
||||
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
|
||||
# convert all outputs to tensor if possible
|
||||
return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
|
||||
return _tensor_metric_output_conversion(func_convert_inputs)
|
||||
|
||||
|
||||
def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator Handling the argument conversion for metrics working on tensors.
|
||||
All inputs of the decorated function and all numpy arrays and numbers in
|
||||
it's outputs will be converted to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
# converts all inputs to tensor if possible
|
||||
# we need to include tensors here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
|
||||
# convert all outputs to tensor if possible
|
||||
return _tensor_collection_metric_output_conversion(func_convert_inputs)
|
||||
|
||||
|
||||
def _sync_ddp_if_available(result: Union[torch.Tensor],
|
||||
|
@ -157,9 +227,8 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
|
|||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum.
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
reduced value
|
||||
|
||||
"""
|
||||
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
|
@ -177,11 +246,32 @@ def _sync_ddp_if_available(result: Union[torch.Tensor],
|
|||
return result
|
||||
|
||||
|
||||
def sync_ddp(group: Optional[Any] = None,
|
||||
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator syncs a functions outputs across different processes for DDP.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(apply_to_collection, torch.Tensor,
|
||||
_sync_ddp_if_available, group=group,
|
||||
reduce_op=reduce_op)(func_to_decorate)
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def numpy_metric(group: Optional[Any] = None,
|
||||
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on numpy arrays.
|
||||
|
||||
It handles the argument conversion and DDP reduction for metrics working on numpy.
|
||||
All inputs of the decorated function will be converted to numpy and all
|
||||
outputs will be converted to tensors.
|
||||
|
@ -191,15 +281,12 @@ def numpy_metric(group: Optional[Any] = None,
|
|||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
|
||||
group=group,
|
||||
reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
|
||||
return sync_ddp(group=group, reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
@ -208,7 +295,6 @@ def tensor_metric(group: Optional[Any] = None,
|
|||
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on tensors.
|
||||
|
||||
It handles the argument conversion and DDP reduction for metrics working on tensors.
|
||||
All inputs and outputs of the decorated function will be converted to tensors.
|
||||
In DDP Training all output tensors will be reduced according to the given rules.
|
||||
|
@ -217,14 +303,34 @@ def tensor_metric(group: Optional[Any] = None,
|
|||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Returns:
|
||||
Return:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp_if_available,
|
||||
group=group,
|
||||
reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
|
||||
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def tensor_collection_metric(group: Optional[Any] = None,
|
||||
reduce_op: Optional[torch.distributed.ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on tensors and returning collections
|
||||
that cannot be converted to tensors.
|
||||
It handles the argument conversion and DDP reduction for metrics working on tensors.
|
||||
All inputs and outputs of the decorated function will be converted to tensors.
|
||||
In DDP Training all output tensors will be reduced according to the given rules.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_collection_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
||||
|
|
|
@ -0,0 +1,693 @@
|
|||
from collections import Sequence
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||
|
||||
|
||||
def to_onehot(
|
||||
tensor: torch.Tensor,
|
||||
n_classes: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts a dense label tensor to one-hot format
|
||||
|
||||
Args:
|
||||
tensor: dense label tensor, with shape [N, d1, d2, ...]
|
||||
|
||||
n_classes: number of classes C
|
||||
|
||||
Output:
|
||||
A sparse label tensor with shape [N, C, d1, d2, ...]
|
||||
"""
|
||||
if n_classes is None:
|
||||
n_classes = int(tensor.max().detach().item() + 1)
|
||||
dtype, device, shape = tensor.dtype, tensor.device, tensor.shape
|
||||
tensor_onehot = torch.zeros(shape[0], n_classes, *shape[1:],
|
||||
dtype=dtype, device=device)
|
||||
index = tensor.long().unsqueeze(1).expand_as(tensor_onehot)
|
||||
return tensor_onehot.scatter_(1, index, 1.0)
|
||||
|
||||
|
||||
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
Converts a tensor of probabilities to a dense label tensor
|
||||
|
||||
Args:
|
||||
tensor: probabilities to get the categorical label [N, d1, d2, ...]
|
||||
argmax_dim: dimension to apply (default: 1)
|
||||
|
||||
Return:
|
||||
A tensor with categorical labels [N, d2, ...]
|
||||
"""
|
||||
return torch.argmax(tensor, dim=argmax_dim)
|
||||
|
||||
|
||||
def get_num_classes(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: Optional[int],
|
||||
) -> int:
|
||||
"""
|
||||
Returns the number of classes for a given prediction and target tensor.
|
||||
|
||||
Args:
|
||||
pred: predicted values
|
||||
target: true labels
|
||||
num_classes: number of classes if known (default: None)
|
||||
|
||||
Return:
|
||||
An integer that represents the number of classes.
|
||||
"""
|
||||
if num_classes is None:
|
||||
if pred.ndim > target.ndim:
|
||||
num_classes = pred.size(1)
|
||||
else:
|
||||
num_classes = int(target.max().detach().item() + 1)
|
||||
return num_classes
|
||||
|
||||
|
||||
def stat_scores(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
class_index: int, argmax_dim: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Calculates the number of true positive, falsepositivee, true negative
|
||||
and false negative for a specific class
|
||||
|
||||
Args:
|
||||
pred: prediction tensor
|
||||
|
||||
target: target tensor
|
||||
|
||||
class_index: class to calculate over
|
||||
|
||||
argmax_dim: if pred is a tensor of probabilities, this indicates the
|
||||
axis the argmax transformation will be applied over
|
||||
|
||||
Return:
|
||||
Tensors in the following order: True Positive, False Positive, True Negative, False Negative
|
||||
|
||||
"""
|
||||
if pred.ndim == target.ndim + 1:
|
||||
pred = to_categorical(pred, argmax_dim=argmax_dim)
|
||||
|
||||
tp = ((pred == class_index) * (target == class_index)).to(torch.long).sum()
|
||||
fp = ((pred == class_index) * (target != class_index)).to(torch.long).sum()
|
||||
tn = ((pred != class_index) * (target != class_index)).to(torch.long).sum()
|
||||
fn = ((pred != class_index) * (target == class_index)).to(torch.long).sum()
|
||||
|
||||
return tp, fp, tn, fn
|
||||
|
||||
|
||||
def stat_scores_multiple_classes(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
argmax_dim: int = 1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Calls the stat_scores function iteratively for all classes, thus
|
||||
calculating the number of true postive, false postive, true negative
|
||||
and false negative for each class
|
||||
|
||||
Args:
|
||||
pred: prediction tensor
|
||||
target: target tensor
|
||||
class_index: class to calculate over
|
||||
argmax_dim: if pred is a tensor of probabilities, this indicates the
|
||||
axis the argmax transformation will be applied over
|
||||
|
||||
Return:
|
||||
Returns tensors for: tp, fp, tn, fn
|
||||
|
||||
"""
|
||||
num_classes = get_num_classes(pred=pred, target=target,
|
||||
num_classes=num_classes)
|
||||
|
||||
if pred.ndim == target.ndim + 1:
|
||||
pred = to_categorical(pred, argmax_dim=argmax_dim)
|
||||
|
||||
tps = torch.zeros((num_classes,), device=pred.device)
|
||||
fps = torch.zeros((num_classes,), device=pred.device)
|
||||
tns = torch.zeros((num_classes,), device=pred.device)
|
||||
fns = torch.zeros((num_classes,), device=pred.device)
|
||||
|
||||
for c in range(num_classes):
|
||||
tps[c], fps[c], tns[c], fns[c] = stat_scores(pred=pred, target=target,
|
||||
class_index=c)
|
||||
|
||||
return tps, fps, tns, fns
|
||||
|
||||
|
||||
def accuracy(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction='elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the accuracy classification score
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
num_classes: number of classes
|
||||
reduction: a method for reducing accuracies over labels (default: takes the mean)
|
||||
Available reduction methods:
|
||||
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
|
||||
Return:
|
||||
A Tensor with the classification score.
|
||||
"""
|
||||
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
|
||||
num_classes=num_classes)
|
||||
|
||||
if not (target > 0).any() and num_classes is None:
|
||||
raise RuntimeError("cannot infer num_classes when target is all zero")
|
||||
|
||||
accuracies = (tps + tns) / (tps + tns + fps + fns)
|
||||
|
||||
return reduce(accuracies, reduction=reduction)
|
||||
|
||||
|
||||
def confusion_matrix(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
normalize: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
|
||||
in group i that were predicted in group j.
|
||||
|
||||
Args:
|
||||
pred: estimated targets
|
||||
target: ground truth labels
|
||||
normalize: normalizes confusion matrix
|
||||
|
||||
Return:
|
||||
Tensor, confusion matrix C [num_classes, num_classes ]
|
||||
"""
|
||||
num_classes = get_num_classes(pred, target, None)
|
||||
|
||||
d = target.size(-1)
|
||||
batch_vec = torch.arange(target.size(-1))
|
||||
# this will account for multilabel
|
||||
unique_labels = batch_vec * num_classes ** 2 + target.view(-1) * num_classes + pred.view(-1)
|
||||
|
||||
bins = torch.bincount(unique_labels, minlength=d * num_classes ** 2)
|
||||
cm = bins.reshape(d, num_classes, num_classes).squeeze().float()
|
||||
|
||||
if normalize:
|
||||
cm = cm / cm.sum(-1)
|
||||
|
||||
return cm
|
||||
|
||||
|
||||
def precision_recall(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes precision and recall for different thresholds
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
num_classes: number of classes
|
||||
reduction: method for reducing precision-recall values (default: takes the mean)
|
||||
Available reduction methods:
|
||||
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
|
||||
Return:
|
||||
Tensor with precision and recall
|
||||
"""
|
||||
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
|
||||
target=target,
|
||||
num_classes=num_classes)
|
||||
|
||||
tps = tps.to(torch.float)
|
||||
fps = fps.to(torch.float)
|
||||
fns = fns.to(torch.float)
|
||||
|
||||
precision = tps / (tps + fps)
|
||||
recall = tps / (tps + fns)
|
||||
|
||||
precision = reduce(precision, reduction=reduction)
|
||||
recall = reduce(recall, reduction=reduction)
|
||||
return precision, recall
|
||||
|
||||
|
||||
def precision(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes precision score.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
num_classes: number of classes
|
||||
reduction: method for reducing precision values (default: takes the mean)
|
||||
Available reduction methods:
|
||||
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
|
||||
Return:
|
||||
Tensor with precision.
|
||||
"""
|
||||
return precision_recall(pred=pred, target=target,
|
||||
num_classes=num_classes, reduction=reduction)[0]
|
||||
|
||||
|
||||
def recall(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes recall score.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
num_classes: number of classes
|
||||
reduction: method for reducing recall values (default: takes the mean)
|
||||
Available reduction methods:
|
||||
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements
|
||||
|
||||
Return:
|
||||
Tensor with recall.
|
||||
"""
|
||||
return precision_recall(pred=pred, target=target,
|
||||
num_classes=num_classes, reduction=reduction)[1]
|
||||
|
||||
|
||||
def fbeta_score(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
beta: float,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the F-beta score which is a weighted harmonic mean of precision and recall.
|
||||
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
beta: weights recall when combining the score.
|
||||
beta < 1: more weight to precision.
|
||||
beta > 1 more weight to recall
|
||||
beta = 0: only precision
|
||||
beta -> inf: only recall
|
||||
num_classes: number of classes
|
||||
reduction: method for reducing F-score (default: takes the mean)
|
||||
Available reduction methods:
|
||||
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements.
|
||||
|
||||
Return:
|
||||
Tensor with the value of F-score. It is a value between 0-1.
|
||||
"""
|
||||
prec, rec = precision_recall(pred=pred, target=target,
|
||||
num_classes=num_classes,
|
||||
reduction='none')
|
||||
|
||||
nom = (1 + beta ** 2) * prec * rec
|
||||
denom = ((beta ** 2) * prec + rec)
|
||||
fbeta = nom / denom
|
||||
|
||||
return reduce(fbeta, reduction=reduction)
|
||||
|
||||
|
||||
def f1_score(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction='elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes F1-score a.k.a F-measure.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
num_classes: number of classes
|
||||
reduction: method for reducing F1-score (default: takes the mean)
|
||||
Available reduction methods:
|
||||
|
||||
- elementwise_mean: takes the mean
|
||||
- none: pass array
|
||||
- sum: add elements.
|
||||
|
||||
Return:
|
||||
Tensor containing F1-score
|
||||
"""
|
||||
return fbeta_score(pred=pred, target=target, beta=1.,
|
||||
num_classes=num_classes, reduction=reduction)
|
||||
|
||||
|
||||
def _binary_clf_curve(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
pos_label: int = 1.,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
|
||||
"""
|
||||
if sample_weight is not None and not isinstance(sample_weight, torch.Tensor):
|
||||
sample_weight = torch.tensor(sample_weight, device=pred.device, dtype=torch.float)
|
||||
|
||||
# remove class dimension if necessary
|
||||
if pred.ndim > target.ndim:
|
||||
pred = pred[:, 0]
|
||||
desc_score_indices = torch.argsort(pred, descending=True)
|
||||
|
||||
pred = pred[desc_score_indices]
|
||||
target = target[desc_score_indices]
|
||||
|
||||
if sample_weight is not None:
|
||||
weight = sample_weight[desc_score_indices]
|
||||
else:
|
||||
weight = 1.
|
||||
|
||||
# pred typically has many tied values. Here we extract
|
||||
# the indices associated with the distinct values. We also
|
||||
# concatenate a value for the end of the curve.
|
||||
distinct_value_indices = torch.where(pred[1:] - pred[:-1])[0]
|
||||
threshold_idxs = torch.cat([distinct_value_indices,
|
||||
torch.tensor([target.size(0) - 1])])
|
||||
|
||||
target = (target == pos_label).to(torch.long)
|
||||
tps = torch.cumsum(target * weight, dim=0)[threshold_idxs]
|
||||
|
||||
if sample_weight is not None:
|
||||
# express fps as a cumsum to ensure fps is increasing even in
|
||||
# the presence of floating point errors
|
||||
fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
|
||||
|
||||
else:
|
||||
fps = 1 + threshold_idxs - tps
|
||||
|
||||
return fps, tps, pred[threshold_idxs]
|
||||
|
||||
|
||||
def roc(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
pos_label: int = 1.,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
sample_weight: sample weights
|
||||
pos_label: the label for the positive class (default: 1)
|
||||
|
||||
Return:
|
||||
[Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
|
||||
"""
|
||||
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=pos_label)
|
||||
|
||||
# Add an extra threshold position
|
||||
# to make sure that the curve starts at (0, 0)
|
||||
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
|
||||
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
|
||||
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])
|
||||
|
||||
if fps[-1] <= 0:
|
||||
raise ValueError("No negative samples in targets, false positive value should be meaningless")
|
||||
|
||||
fpr = fps / fps[-1]
|
||||
|
||||
if tps[-1] <= 0:
|
||||
raise ValueError("No positive samples in targets, true positive value should be meaningless")
|
||||
|
||||
tpr = tps / tps[-1]
|
||||
|
||||
return fpr, tpr, thresholds
|
||||
|
||||
|
||||
def multiclass_roc(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
num_classes: Optional[int] = None,
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
sample_weight: sample weights
|
||||
num_classes: number of classes (default: None, computes automatically from data)
|
||||
|
||||
Return:
|
||||
[num_classes, Tensor, Tensor, Tensor]: returns roc for each class.
|
||||
number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
|
||||
"""
|
||||
num_classes = get_num_classes(pred, target, num_classes)
|
||||
|
||||
class_roc_vals = []
|
||||
for c in range(num_classes):
|
||||
pred_c = pred[:, c]
|
||||
|
||||
class_roc_vals.append(roc(pred=pred_c, target=target,
|
||||
sample_weight=sample_weight, pos_label=c))
|
||||
|
||||
return tuple(class_roc_vals)
|
||||
|
||||
|
||||
def precision_recall_curve(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
pos_label: int = 1.,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes precision-recall pairs for different thresholds.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
sample_weight: sample weights
|
||||
pos_label: the label for the positive class (default: 1.)
|
||||
|
||||
Return:
|
||||
[Tensor, Tensor, Tensor]: precision, recall, thresholds
|
||||
"""
|
||||
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=pos_label)
|
||||
|
||||
precision = tps / (tps + fps)
|
||||
recall = tps / tps[-1]
|
||||
|
||||
# stop when full recall attained
|
||||
# and reverse the outputs so recall is decreasing
|
||||
last_ind = torch.where(tps == tps[-1])[0][0]
|
||||
sl = slice(0, last_ind.item() + 1)
|
||||
|
||||
# need to call reversed explicitly, since including that to slice would
|
||||
# introduce negative strides that are not yet supported in pytorch
|
||||
precision = torch.cat([reversed(precision[sl]),
|
||||
torch.ones(1, dtype=precision.dtype,
|
||||
device=precision.device)])
|
||||
|
||||
recall = torch.cat([reversed(recall[sl]),
|
||||
torch.zeros(1, dtype=recall.dtype,
|
||||
device=recall.device)])
|
||||
|
||||
thresholds = torch.tensor(reversed(thresholds[sl]))
|
||||
|
||||
return precision, recall, thresholds
|
||||
|
||||
|
||||
def multiclass_precision_recall_curve(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
num_classes: Optional[int] = None,
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Computes precision-recall pairs for different thresholds given a multiclass scores.
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
sample_weight: sample weight
|
||||
num_classes: number of classes
|
||||
|
||||
Return:
|
||||
[num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds
|
||||
"""
|
||||
num_classes = get_num_classes(pred, target, num_classes)
|
||||
|
||||
class_pr_vals = []
|
||||
for c in range(num_classes):
|
||||
pred_c = pred[:, c]
|
||||
|
||||
class_pr_vals.append(precision_recall_curve(
|
||||
pred=pred_c,
|
||||
target=target,
|
||||
sample_weight=sample_weight, pos_label=c))
|
||||
|
||||
return tuple(class_pr_vals)
|
||||
|
||||
|
||||
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
|
||||
"""
|
||||
Computes Area Under the Curve (AUC) using the trapezoidal rule
|
||||
|
||||
Args:
|
||||
x: x-coordinates
|
||||
y: y-coordinates
|
||||
reorder: reorder coordinates, so they are increasing.
|
||||
|
||||
Return:
|
||||
AUC score (float)
|
||||
"""
|
||||
direction = 1.
|
||||
|
||||
if reorder:
|
||||
# can't use lexsort here since it is not implemented for torch
|
||||
order = torch.argsort(x)
|
||||
x, y = x[order], y[order]
|
||||
else:
|
||||
dx = x[1:] - x[:-1]
|
||||
if (dx < 0).any():
|
||||
if (dx, 0).all():
|
||||
direction = -1.
|
||||
else:
|
||||
raise ValueError("Reordering is not turned on, and "
|
||||
"the x array is not increasing: %s" % x)
|
||||
|
||||
return direction * torch.trapz(y, x)
|
||||
|
||||
|
||||
def auc_decorator(reorder: bool = True) -> Callable:
|
||||
def wrapper(func_to_decorate: Callable) -> Callable:
|
||||
@wraps(func_to_decorate)
|
||||
def new_func(*args, **kwargs) -> torch.Tensor:
|
||||
x, y = func_to_decorate(*args, **kwargs)[:2]
|
||||
|
||||
return auc(x, y, reorder=reorder)
|
||||
|
||||
return new_func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def multiclass_auc_decorator(reorder: bool = True) -> Callable:
|
||||
def wrapper(func_to_decorate: Callable) -> Callable:
|
||||
def new_func(*args, **kwargs) -> torch.Tensor:
|
||||
results = []
|
||||
for class_result in func_to_decorate(*args, **kwargs):
|
||||
x, y = class_result[:2]
|
||||
results.append(auc(x, y, reorder=reorder))
|
||||
|
||||
return torch.cat(results)
|
||||
|
||||
return new_func
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def auroc(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
pos_label: int = 1.,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores
|
||||
|
||||
Args:
|
||||
pred: estimated probabilities
|
||||
target: ground-truth labels
|
||||
sample_weight: sample weights
|
||||
pos_label: the label for the positive class (default: 1.)
|
||||
"""
|
||||
|
||||
@auc_decorator(reorder=True)
|
||||
def _auroc(pred, target, sample_weight, pos_label):
|
||||
return roc(pred, target, sample_weight, pos_label)
|
||||
|
||||
return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)
|
||||
|
||||
|
||||
def average_precision(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
pos_label: int = 1.,
|
||||
) -> torch.Tensor:
|
||||
precision, recall, _ = precision_recall_curve(pred=pred, target=target,
|
||||
sample_weight=sample_weight,
|
||||
pos_label=pos_label)
|
||||
# Return the step function integral
|
||||
# The following works because the last entry of precision is
|
||||
# guaranteed to be 1, as returned by precision_recall_curve
|
||||
return -torch.sum(recall[1:] - recall[:-1] * precision[:-1])
|
||||
|
||||
|
||||
def dice_score(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
bg: bool = False,
|
||||
nan_score: float = 0.0,
|
||||
no_fg_score: float = 0.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
) -> torch.Tensor:
|
||||
n_classes = pred.shape[1]
|
||||
bg = (1 - int(bool(bg)))
|
||||
scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)
|
||||
for i in range(bg, n_classes):
|
||||
if not (target == i).any():
|
||||
# no foreground class
|
||||
scores[i - bg] += no_fg_score
|
||||
continue
|
||||
|
||||
tp, fp, tn, fn = stat_scores(pred=pred, target=target, class_index=i)
|
||||
|
||||
denom = (2 * tp + fp + fn).to(torch.float)
|
||||
|
||||
if torch.isclose(denom, torch.zeros_like(denom)).any():
|
||||
# nan result
|
||||
score_cls = nan_score
|
||||
else:
|
||||
score_cls = (2 * tp).to(torch.float) / denom
|
||||
|
||||
scores[i - bg] += score_cls
|
||||
return reduce(scores, reduction=reduction)
|
|
@ -0,0 +1,24 @@
|
|||
import torch
|
||||
|
||||
|
||||
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
|
||||
"""
|
||||
Reduces a given tensor by a given reduction method
|
||||
|
||||
Args:
|
||||
to_reduce : the tensor, which shall be reduced
|
||||
reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
|
||||
|
||||
Return:
|
||||
reduced Tensor
|
||||
|
||||
Raise:
|
||||
ValueError if an invalid reduction parameter was given
|
||||
"""
|
||||
if reduction == 'elementwise_mean':
|
||||
return torch.mean(to_reduce)
|
||||
if reduction == 'none':
|
||||
return to_reduce
|
||||
if reduction == 'sum':
|
||||
return torch.sum(to_reduce)
|
||||
raise ValueError('Reduction parameter unknown.')
|
|
@ -3,16 +3,16 @@ from typing import Any, Optional
|
|||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from torch.nn import Module
|
||||
|
||||
from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric
|
||||
from pytorch_lightning.metrics.converters import (
|
||||
tensor_metric, numpy_metric, tensor_collection_metric)
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
|
||||
__all__ = ['Metric', 'TensorMetric', 'NumpyMetric']
|
||||
|
||||
|
||||
class Metric(ABC, DeviceDtypeModuleMixin, Module):
|
||||
class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC):
|
||||
"""
|
||||
Abstract base class for metric implementation.
|
||||
|
||||
|
@ -20,6 +20,7 @@ class Metric(ABC, DeviceDtypeModuleMixin, Module):
|
|||
1. Return multiple Outputs
|
||||
2. Handle their own DDP sync
|
||||
"""
|
||||
|
||||
def __init__(self, name: str):
|
||||
"""
|
||||
Args:
|
||||
|
@ -49,6 +50,7 @@ class TensorMetric(Metric):
|
|||
All inputs and outputs will be casted to tensors if necessary.
|
||||
Already handles DDP sync and input/output conversions.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str,
|
||||
reduce_group: Optional[Any] = None,
|
||||
reduce_op: Optional[Any] = None):
|
||||
|
@ -73,6 +75,47 @@ class TensorMetric(Metric):
|
|||
_to_device_dtype)
|
||||
|
||||
|
||||
class TensorCollectionMetric(Metric):
|
||||
"""
|
||||
Base class for metric implementation operating directly on tensors.
|
||||
All inputs will be casted to tensors if necessary. Outputs won't be casted.
|
||||
Already handles DDP sync and input conversions.
|
||||
|
||||
This class differs from :class:`TensorMetric`, as it assumes all outputs to
|
||||
be collections of tensors and does not explicitly convert them. This is
|
||||
necessary, since some collections (like for ROC, Precision-Recall Curve etc.)
|
||||
cannot be converted to tensors at the highest level.
|
||||
All numpy arrays and numbers occuring in these outputs will still be converted.
|
||||
|
||||
Use this class as a baseclass, whenever you want to ensure inputs are
|
||||
tensors and outputs cannot be converted to tensors automatically
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, name: str,
|
||||
reduce_group: Optional[Any] = None,
|
||||
reduce_op: Optional[Any] = None):
|
||||
"""
|
||||
|
||||
Args:
|
||||
name: the metric's name
|
||||
reduce_group: the process group for DDP reduces (only needed for DDP training).
|
||||
Defaults to all processes (world)
|
||||
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
|
||||
Defaults to sum.
|
||||
"""
|
||||
super().__init__(name)
|
||||
self._orig_call = tensor_collection_metric(group=reduce_group,
|
||||
reduce_op=reduce_op)(super().__call__)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> torch.Tensor:
|
||||
def _to_device_dtype(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.to(device=self.device, dtype=self.dtype, non_blocking=True)
|
||||
|
||||
return apply_to_collection(self._orig_call(*args, **kwargs), torch.Tensor,
|
||||
_to_device_dtype)
|
||||
|
||||
|
||||
class NumpyMetric(Metric):
|
||||
"""
|
||||
Base class for metric implementation operating on numpy arrays.
|
||||
|
@ -80,6 +123,7 @@ class NumpyMetric(Metric):
|
|||
be casted to tensors if necessary.
|
||||
Already handles DDP sync and input/output conversions.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str,
|
||||
reduce_group: Optional[Any] = None,
|
||||
reduce_op: Optional[Any] = None):
|
||||
|
|
|
@ -1,130 +0,0 @@
|
|||
import numbers
|
||||
from typing import Union, Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data._utils.collate import default_convert
|
||||
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
|
||||
def _apply_to_inputs(func_to_apply, *dec_args, **dec_kwargs):
|
||||
def decorator_fn(func_to_decorate):
|
||||
def new_func(*args, **kwargs):
|
||||
args = func_to_apply(args, *dec_args, **dec_kwargs)
|
||||
kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
|
||||
return func_to_decorate(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def _apply_to_outputs(func_to_apply, *dec_args, **dec_kwargs):
|
||||
def decorator_fn(function_to_decorate):
|
||||
def new_func(*args, **kwargs):
|
||||
result = function_to_decorate(*args, **kwargs)
|
||||
return func_to_apply(result, *dec_args, **dec_kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def _convert_to_tensor(data: Any) -> Any:
|
||||
"""
|
||||
Maps all kind of collections and numbers to tensors
|
||||
|
||||
Args:
|
||||
data: the data to convert to tensor
|
||||
|
||||
Returns:
|
||||
the converted data
|
||||
|
||||
"""
|
||||
if isinstance(data, numbers.Number):
|
||||
return torch.tensor([data])
|
||||
else:
|
||||
return default_convert(data)
|
||||
|
||||
|
||||
def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
|
||||
"""
|
||||
converts all tensors and numpy arrays to numpy arrays
|
||||
Args:
|
||||
data: the tensor or array to convert to numpy
|
||||
|
||||
Returns:
|
||||
the resulting numpy array
|
||||
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.cpu().detach().numpy()
|
||||
elif isinstance(data, numbers.Number):
|
||||
return np.array([data])
|
||||
return data
|
||||
|
||||
|
||||
def _numpy_metric_conversion(func_to_decorate):
|
||||
# Applies collection conversion from tensor to numpy to all inputs
|
||||
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _apply_to_inputs(
|
||||
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate)
|
||||
# converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
|
||||
func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
|
||||
return func_convert_in_out
|
||||
|
||||
|
||||
def _tensor_metric_conversion(func_to_decorate):
|
||||
# Converts all inputs to tensor if possible
|
||||
func_convert_inputs = _apply_to_inputs(_convert_to_tensor)(func_to_decorate)
|
||||
# convert all outputs to tensor if possible
|
||||
return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs)
|
||||
|
||||
|
||||
def _sync_ddp(result: Union[torch.Tensor],
|
||||
group: Any = torch.distributed.group.WORLD,
|
||||
reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Function to reduce the tensors from several ddp processes to one master process
|
||||
|
||||
Args:
|
||||
result: the value to sync and reduce (typically tensor or number)
|
||||
device: the device to put the synced and reduced value to
|
||||
dtype: the datatype to convert the synced and reduced value to
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Returns:
|
||||
reduced value
|
||||
|
||||
"""
|
||||
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
# sync all processes before reduction
|
||||
torch.distributed.barrier(group=group)
|
||||
torch.distributed.all_reduce(result, op=reduce_op, group=group,
|
||||
async_op=False)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def numpy_metric(group: Any = torch.distributed.group.WORLD,
|
||||
reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
|
||||
group=group,
|
||||
reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def tensor_metric(group: Any = torch.distributed.group.WORLD,
|
||||
reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM):
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp,
|
||||
group=group,
|
||||
reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
|
@ -0,0 +1,309 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.metrics.functional.classification import (
|
||||
to_onehot,
|
||||
to_categorical,
|
||||
get_num_classes,
|
||||
stat_scores,
|
||||
stat_scores_multiple_classes,
|
||||
accuracy,
|
||||
confusion_matrix,
|
||||
precision,
|
||||
recall,
|
||||
fbeta_score,
|
||||
f1_score,
|
||||
_binary_clf_curve,
|
||||
dice_score,
|
||||
average_precision,
|
||||
auroc,
|
||||
precision_recall_curve,
|
||||
roc,
|
||||
auc,
|
||||
)
|
||||
|
||||
|
||||
def test_onehot():
|
||||
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
expected = torch.tensor([
|
||||
[
|
||||
[1, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]
|
||||
], [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 1]
|
||||
]
|
||||
])
|
||||
|
||||
assert test_tensor.shape == (2, 5)
|
||||
assert expected.shape == (2, 10, 5)
|
||||
|
||||
onehot_classes = to_onehot(test_tensor, n_classes=10)
|
||||
onehot_no_classes = to_onehot(test_tensor)
|
||||
|
||||
assert torch.allclose(onehot_classes, onehot_no_classes)
|
||||
|
||||
assert onehot_classes.shape == expected.shape
|
||||
assert onehot_no_classes.shape == expected.shape
|
||||
|
||||
assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
|
||||
assert torch.allclose(expected.to(onehot_classes), onehot_classes)
|
||||
|
||||
|
||||
def test_to_categorical():
|
||||
test_tensor = torch.tensor([
|
||||
[
|
||||
[1, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0]
|
||||
], [
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 1]
|
||||
]
|
||||
]).to(torch.float)
|
||||
|
||||
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
assert expected.shape == (2, 5)
|
||||
assert test_tensor.shape == (2, 10, 5)
|
||||
|
||||
result = to_categorical(test_tensor)
|
||||
|
||||
assert result.shape == expected.shape
|
||||
assert torch.allclose(result, expected.to(result.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
|
||||
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
|
||||
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
|
||||
pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
|
||||
])
|
||||
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
|
||||
assert get_num_classes(pred, target, num_classes) == expected_num_classes
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
|
||||
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1),
|
||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1)
|
||||
])
|
||||
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
|
||||
tp, fp, tn, fn = stat_scores(pred, target, class_index=4)
|
||||
|
||||
assert tp.item() == expected_tp
|
||||
assert fp.item() == expected_fp
|
||||
assert tn.item() == expected_tn
|
||||
assert fn.item() == expected_fn
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn'], [
|
||||
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]),
|
||||
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1]),
|
||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
|
||||
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1])
|
||||
])
|
||||
def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn):
|
||||
tp, fp, tn, fn = stat_scores_multiple_classes(pred, target)
|
||||
|
||||
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
|
||||
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
|
||||
assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
|
||||
assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
|
||||
|
||||
|
||||
def test_multilabel_accuracy():
|
||||
# Dense label indicator matrix format
|
||||
y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
|
||||
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
|
||||
|
||||
assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([0.8333333134651184] * 2))
|
||||
assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.]))
|
||||
assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.]))
|
||||
assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.]))
|
||||
assert torch.allclose(accuracy(y1, torch.logical_not(y1), reduction='none'), torch.tensor([0., 0.]))
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
accuracy(y2, torch.zeros_like(y2), reduction='none')
|
||||
|
||||
|
||||
def test_confusion_matrix():
|
||||
target = (torch.arange(120) % 3).view(-1, 1)
|
||||
pred = target.clone()
|
||||
cm = confusion_matrix(pred, target, normalize=True)
|
||||
|
||||
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))
|
||||
|
||||
pred = torch.zeros_like(pred)
|
||||
cm = confusion_matrix(pred, target, normalize=True)
|
||||
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
|
||||
pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
|
||||
pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
|
||||
])
|
||||
def test_precision_recall(pred, target, expected_prec, expected_rec):
|
||||
prec = precision(pred, target, reduction='none')
|
||||
rec = recall(pred, target, reduction='none')
|
||||
|
||||
assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
|
||||
assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]),
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]),
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
|
||||
])
|
||||
def test_fbeta_score(pred, target, beta, exp_score):
|
||||
score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
|
||||
])
|
||||
def test_f1_score(pred, target, exp_score):
|
||||
score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [
|
||||
pytest.param(1, 1., 42),
|
||||
pytest.param(None, 1., 42),
|
||||
])
|
||||
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
|
||||
# TODO: move back the pred and target to test func arguments
|
||||
# if you fix the array inside the function, you'd also have fix the shape,
|
||||
# because when the array changes, you also have to fix the shape
|
||||
seed_everything(0)
|
||||
pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
|
||||
target = torch.tensor([0, 1] * 50, dtype=torch.int)
|
||||
if sample_weight is not None:
|
||||
sample_weight = torch.ones_like(pred) * sample_weight
|
||||
|
||||
fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
|
||||
|
||||
assert isinstance(tps, torch.Tensor)
|
||||
assert isinstance(fps, torch.Tensor)
|
||||
assert isinstance(thresh, torch.Tensor)
|
||||
assert tps.shape == (exp_shape,)
|
||||
assert fps.shape == (exp_shape,)
|
||||
assert thresh.shape == (exp_shape,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
|
||||
pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
|
||||
])
|
||||
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
|
||||
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
|
||||
assert p.size() == r.size()
|
||||
assert p.size(0) == t.size(0) + 1
|
||||
|
||||
assert torch.allclose(p, torch.tensor(expected_p).to(p))
|
||||
assert torch.allclose(r, torch.tensor(expected_r).to(r))
|
||||
assert torch.allclose(t, torch.tensor(expected_t).to(t))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
|
||||
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
|
||||
pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
|
||||
pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
|
||||
pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
|
||||
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
|
||||
])
|
||||
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
|
||||
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
|
||||
|
||||
assert fpr.shape == tpr.shape
|
||||
assert fpr.size(0) == thresh.size(0)
|
||||
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
|
||||
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([0, 0, 1, 1], [0, 0, 1, 1], 1.),
|
||||
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
|
||||
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
|
||||
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
|
||||
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
|
||||
])
|
||||
def test_auroc(pred, target, expected):
|
||||
score = auroc(torch.tensor(pred), torch.tensor(target)).item()
|
||||
assert score == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['x', 'y', 'expected'], [
|
||||
pytest.param([0, 1], [0, 1], 0.5),
|
||||
pytest.param([1, 0], [0, 1], 0.5),
|
||||
pytest.param([1, 0, 0], [0, 1, 1], 0.5),
|
||||
pytest.param([0, 1], [1, 1], 1),
|
||||
pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
|
||||
])
|
||||
def test_auc(x, y, expected):
|
||||
# Test Area Under Curve (AUC) computation
|
||||
assert auc(torch.tensor(x), torch.tensor(y)) == expected
|
||||
|
||||
|
||||
def test_average_precision_constant_values():
|
||||
# Check the average_precision_score of a constant predictor is
|
||||
# the TPR
|
||||
|
||||
# Generate a dataset with 25% of positives
|
||||
target = torch.zeros(100, dtype=torch.float)
|
||||
target[::4] = 1
|
||||
# And a constant score
|
||||
pred = torch.ones(100)
|
||||
# The precision is then the fraction of positive whatever the recall
|
||||
# is, as there is only one threshold:
|
||||
assert average_precision(pred, target).item() == .25
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
|
||||
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),
|
||||
pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
|
||||
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
|
||||
])
|
||||
def test_dice_score(pred, target, expected):
|
||||
score = dice_score(torch.tensor(pred), torch.tensor(target))
|
||||
assert score == expected
|
||||
|
||||
# example data taken from
|
||||
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
|
|
@ -0,0 +1,15 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||
|
||||
|
||||
def test_reduce():
|
||||
start_tensor = torch.rand(50, 40, 30)
|
||||
|
||||
assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
|
||||
assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
|
||||
assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
reduce(start_tensor, 'error_reduction')
|
|
@ -0,0 +1,227 @@
|
|||
# NOTE: This file only tests if modules with arguments are running fine.
|
||||
# The actual metric implementation is tested in functional/test_classification.py
|
||||
# Especially reduction and reducing across processes won't be tested here!
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.classification import (
|
||||
Accuracy,
|
||||
ConfusionMatrix,
|
||||
PrecisionRecall,
|
||||
Precision,
|
||||
Recall,
|
||||
AveragePrecision,
|
||||
AUROC,
|
||||
FBeta,
|
||||
F1,
|
||||
ROC,
|
||||
MulticlassROC,
|
||||
MulticlassPrecisionRecall,
|
||||
DiceCoefficient,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random():
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_accuracy(num_classes):
|
||||
acc = Accuracy(num_classes=num_classes)
|
||||
|
||||
assert acc.name == 'accuracy'
|
||||
|
||||
result = acc(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
|
||||
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
|
||||
|
||||
assert isinstance(result, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('normalize', [False, True])
|
||||
def test_confusion_matrix(normalize):
|
||||
conf_matrix = ConfusionMatrix(normalize=normalize)
|
||||
assert conf_matrix.name == 'confusion_matrix'
|
||||
|
||||
target = (torch.arange(120) % 3).view(-1, 1)
|
||||
pred = target.clone()
|
||||
|
||||
cm = conf_matrix(pred, target)
|
||||
|
||||
assert isinstance(cm, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2.])
|
||||
def test_precision_recall(pos_label):
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
|
||||
|
||||
pr_curve = PrecisionRecall(pos_label=pos_label)
|
||||
assert pr_curve.name == 'precision_recall_curve'
|
||||
|
||||
pr = pr_curve(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
|
||||
assert isinstance(pr, tuple)
|
||||
assert len(pr) == 3
|
||||
for tmp in pr:
|
||||
assert isinstance(tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_precision(num_classes):
|
||||
precision = Precision(num_classes=num_classes)
|
||||
|
||||
assert precision.name == 'precision'
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
|
||||
|
||||
prec = precision(pred=pred, target=target)
|
||||
|
||||
assert isinstance(prec, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_recall(num_classes):
|
||||
recall = Recall(num_classes=num_classes)
|
||||
|
||||
assert recall.name == 'recall'
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
|
||||
|
||||
rec = recall(pred=pred, target=target)
|
||||
|
||||
assert isinstance(rec, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2])
|
||||
def test_average_precision(pos_label):
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
|
||||
|
||||
avg_prec = AveragePrecision(pos_label=pos_label)
|
||||
assert avg_prec.name == 'AP'
|
||||
|
||||
ap = avg_prec(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
|
||||
assert isinstance(ap, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2])
|
||||
def test_auroc(pos_label):
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
|
||||
|
||||
auroc = AUROC(pos_label=pos_label)
|
||||
assert auroc.name == 'auroc'
|
||||
|
||||
area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
|
||||
assert isinstance(area, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['beta', 'num_classes'], [
|
||||
pytest.param(0., 1),
|
||||
pytest.param(0.5, 1),
|
||||
pytest.param(1., 1),
|
||||
pytest.param(2., 1),
|
||||
pytest.param(0., None),
|
||||
pytest.param(0.5, None),
|
||||
pytest.param(1., None),
|
||||
pytest.param(2., None)
|
||||
])
|
||||
def test_fbeta(beta, num_classes):
|
||||
fbeta = FBeta(beta=beta, num_classes=num_classes)
|
||||
assert fbeta.name == 'fbeta'
|
||||
|
||||
score = fbeta(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
|
||||
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
|
||||
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_f1(num_classes):
|
||||
f1 = F1(num_classes=num_classes)
|
||||
assert f1.name == 'f1'
|
||||
|
||||
score = f1(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
|
||||
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
|
||||
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2])
|
||||
def test_roc(pos_label):
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3])
|
||||
|
||||
roc = ROC(pos_label=pos_label)
|
||||
assert roc.name == 'roc'
|
||||
|
||||
res = roc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
|
||||
assert isinstance(res, tuple)
|
||||
assert len(res) == 3
|
||||
for tmp in res:
|
||||
assert isinstance(tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [4, None])
|
||||
def test_multiclass_roc(num_classes):
|
||||
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
[0.05, 0.85, 0.05, 0.05],
|
||||
[0.05, 0.05, 0.85, 0.05],
|
||||
[0.05, 0.05, 0.05, 0.85]])
|
||||
target = torch.tensor([0, 1, 3, 2])
|
||||
|
||||
multi_roc = MulticlassROC(num_classes=num_classes)
|
||||
|
||||
assert multi_roc.name == 'multiclass_roc'
|
||||
|
||||
res = multi_roc(pred, target)
|
||||
|
||||
assert isinstance(res, tuple)
|
||||
|
||||
if num_classes is not None:
|
||||
assert len(res) == num_classes
|
||||
|
||||
for tmp in res:
|
||||
assert isinstance(tmp, tuple)
|
||||
assert len(tmp) == 3
|
||||
|
||||
for _tmp in tmp:
|
||||
assert isinstance(_tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [4, None])
|
||||
def test_multiclass_pr(num_classes):
|
||||
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
[0.05, 0.85, 0.05, 0.05],
|
||||
[0.05, 0.05, 0.85, 0.05],
|
||||
[0.05, 0.05, 0.05, 0.85]])
|
||||
target = torch.tensor([0, 1, 3, 2])
|
||||
|
||||
multi_pr = MulticlassPrecisionRecall(num_classes=num_classes)
|
||||
|
||||
assert multi_pr.name == 'multiclass_precision_recall_curve'
|
||||
|
||||
pr = multi_pr(pred, target)
|
||||
|
||||
assert isinstance(pr, tuple)
|
||||
|
||||
if num_classes is not None:
|
||||
assert len(pr) == num_classes
|
||||
|
||||
for tmp in pr:
|
||||
assert isinstance(tmp, tuple)
|
||||
assert len(tmp) == 3
|
||||
|
||||
for _tmp in tmp:
|
||||
assert isinstance(_tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('include_background', [True, False])
|
||||
def test_dice_coefficient(include_background):
|
||||
dice_coeff = DiceCoefficient(include_background=include_background)
|
||||
|
||||
assert dice_coeff.name == 'dice'
|
||||
|
||||
dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)),
|
||||
torch.randint(0, 1, (10, 25, 25)))
|
||||
|
||||
assert isinstance(dice, torch.Tensor)
|
|
@ -6,16 +6,19 @@ import torch.multiprocessing as mp
|
|||
|
||||
import tests.base.utils as tutils
|
||||
from pytorch_lightning.metrics.converters import (
|
||||
_apply_to_inputs, _apply_to_outputs, _convert_to_tensor, _convert_to_numpy,
|
||||
_numpy_metric_conversion, _tensor_metric_conversion, _sync_ddp_if_available, tensor_metric, numpy_metric)
|
||||
_apply_to_inputs,
|
||||
_apply_to_outputs,
|
||||
_convert_to_tensor,
|
||||
_convert_to_numpy,
|
||||
_numpy_metric_conversion,
|
||||
_tensor_metric_conversion,
|
||||
_sync_ddp_if_available,
|
||||
tensor_metric,
|
||||
numpy_metric
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['args', 'kwargs'],
|
||||
[pytest.param([], {}),
|
||||
pytest.param([1., 2.], {}),
|
||||
pytest.param([], {'a': 1., 'b': 2.}),
|
||||
pytest.param([1., 2.], {'a': 1., 'b': 2.})])
|
||||
def test_apply_to_inputs(args, kwargs):
|
||||
def test_apply_to_inputs():
|
||||
def apply_fn(inputs, factor):
|
||||
if isinstance(inputs, (float, int)):
|
||||
return inputs * factor
|
||||
|
@ -25,22 +28,24 @@ def test_apply_to_inputs(args, kwargs):
|
|||
return [apply_fn(x, factor) for x in inputs]
|
||||
|
||||
@_apply_to_inputs(apply_fn, factor=2.)
|
||||
def test_fn(*func_args, **func_kwargs):
|
||||
return func_args, func_kwargs
|
||||
def test_fn(*args, **kwargs):
|
||||
return args, kwargs
|
||||
|
||||
result_args, result_kwargs = test_fn(*args, **kwargs)
|
||||
assert isinstance(result_args, (list, tuple))
|
||||
assert isinstance(result_kwargs, dict)
|
||||
assert len(result_args) == len(args)
|
||||
assert len(result_kwargs) == len(kwargs)
|
||||
assert all([k in result_kwargs for k in kwargs.keys()])
|
||||
for arg, result_arg in zip(args, result_args):
|
||||
assert arg * 2. == result_arg
|
||||
for args in [[], [1., 2.]]:
|
||||
for kwargs in [{}, {'a': 1., 'b': 2.}]:
|
||||
result_args, result_kwargs = test_fn(*args, **kwargs)
|
||||
assert isinstance(result_args, (list, tuple))
|
||||
assert isinstance(result_kwargs, dict)
|
||||
assert len(result_args) == len(args)
|
||||
assert len(result_kwargs) == len(kwargs)
|
||||
assert all([k in result_kwargs for k in kwargs.keys()])
|
||||
for arg, result_arg in zip(args, result_args):
|
||||
assert arg * 2. == result_arg
|
||||
|
||||
for key in kwargs.keys():
|
||||
arg = kwargs[key]
|
||||
result_arg = result_kwargs[key]
|
||||
assert arg * 2. == result_arg
|
||||
for key in kwargs.keys():
|
||||
arg = kwargs[key]
|
||||
result_arg = result_kwargs[key]
|
||||
assert arg * 2. == result_arg
|
||||
|
||||
|
||||
def test_apply_to_outputs():
|
||||
|
@ -100,7 +105,7 @@ def test_tensor_metric_conversion():
|
|||
assert result.item() == 5.
|
||||
|
||||
|
||||
def setup_ddp(rank, worldsize, ):
|
||||
def _setup_ddp(rank, worldsize):
|
||||
import os
|
||||
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
|
@ -109,8 +114,8 @@ def setup_ddp(rank, worldsize, ):
|
|||
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
|
||||
|
||||
|
||||
def ddp_test_fn(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
def _ddp_test_fn(rank, worldsize):
|
||||
_setup_ddp(rank, worldsize)
|
||||
tensor = torch.tensor([1.], device='cuda:0')
|
||||
|
||||
reduced_tensor = _sync_ddp_if_available(tensor)
|
||||
|
@ -119,6 +124,7 @@ def ddp_test_fn(rank, worldsize):
|
|||
'Sync-Reduce does not work properly with DDP and Tensors'
|
||||
|
||||
|
||||
@pytest.mark.spawn
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
def test_sync_reduce_ddp():
|
||||
"""Make sure sync-reduce works with DDP"""
|
||||
|
@ -126,7 +132,9 @@ def test_sync_reduce_ddp():
|
|||
tutils.set_random_master_port()
|
||||
|
||||
worldsize = 2
|
||||
mp.spawn(ddp_test_fn, args=(worldsize,), nprocs=worldsize)
|
||||
mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
|
||||
|
||||
# dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_sync_reduce_simple():
|
||||
|
@ -161,16 +169,18 @@ def _test_tensor_metric(is_ddp: bool):
|
|||
|
||||
|
||||
def _ddp_test_tensor_metric(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
_setup_ddp(rank, worldsize)
|
||||
_test_tensor_metric(True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
def test_tensor_metric_ddp():
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
world_size = 2
|
||||
mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)
|
||||
# dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_tensor_metric_simple():
|
||||
|
@ -199,16 +209,19 @@ def _test_numpy_metric(is_ddp: bool):
|
|||
|
||||
|
||||
def _ddp_test_numpy_metric(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
_setup_ddp(rank, worldsize)
|
||||
_test_numpy_metric(True)
|
||||
|
||||
|
||||
@pytest.mark.spawn
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
def test_numpy_metric_ddp():
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
world_size = 2
|
||||
mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
|
||||
# dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_numpy_metric_simple():
|
||||
_test_tensor_metric(False)
|
||||
_test_numpy_metric(False)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
|
||||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric
|
||||
|
||||
|
||||
class DummyTensorMetric(TensorMetric):
|
||||
|
@ -24,7 +24,65 @@ class DummyNumpyMetric(NumpyMetric):
|
|||
return 1.
|
||||
|
||||
|
||||
class DummyTensorCollectionMetric(TensorCollectionMetric):
|
||||
def __init__(self):
|
||||
super().__init__('dummy')
|
||||
|
||||
def forward(self, input1, input2):
|
||||
assert isinstance(input1, torch.Tensor)
|
||||
assert isinstance(input2, torch.Tensor)
|
||||
return 1., 2., 3., 4.
|
||||
|
||||
|
||||
def _test_collection_metric(metric: Metric):
|
||||
""" Test that metric.device, metric.dtype works for metric collection """
|
||||
input1, input2 = torch.tensor([1.]), torch.tensor([2.])
|
||||
|
||||
def change_and_check_device_dtype(device, dtype):
|
||||
metric.to(device=device, dtype=dtype)
|
||||
|
||||
metric_val = metric(input1, input2)
|
||||
assert not isinstance(metric_val, torch.Tensor)
|
||||
|
||||
if device is not None:
|
||||
assert metric.device in [device, torch.device(device)]
|
||||
|
||||
if dtype is not None:
|
||||
assert metric.dtype == dtype
|
||||
|
||||
devices = [None, 'cpu']
|
||||
if torch.cuda.is_available():
|
||||
devices += ['cuda:0']
|
||||
|
||||
for device in devices:
|
||||
for dtype in [None, torch.float32, torch.float64]:
|
||||
change_and_check_device_dtype(device=device, dtype=dtype)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda(0)
|
||||
assert metric.device == torch.device('cuda', index=0)
|
||||
|
||||
metric.cpu()
|
||||
assert metric.device == torch.device('cpu')
|
||||
|
||||
metric.type(torch.int8)
|
||||
assert metric.dtype == torch.int8
|
||||
|
||||
metric.float()
|
||||
assert metric.dtype == torch.float32
|
||||
|
||||
metric.double()
|
||||
assert metric.dtype == torch.float64
|
||||
assert all(out.dtype == torch.float64 for out in metric(input1, input2))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda()
|
||||
metric.half()
|
||||
assert metric.dtype == torch.float16
|
||||
|
||||
|
||||
def _test_metric(metric: Metric):
|
||||
""" Test that metric.device, metric.dtype works for single metric"""
|
||||
input1, input2 = torch.tensor([1.]), torch.tensor([2.])
|
||||
|
||||
def change_and_check_device_dtype(device, dtype):
|
||||
|
@ -83,3 +141,7 @@ def test_tensor_metric():
|
|||
|
||||
def test_numpy_metric():
|
||||
_test_metric(DummyNumpyMetric())
|
||||
|
||||
|
||||
def test_tensor_collection():
|
||||
_test_collection_metric(DummyTensorCollectionMetric())
|
||||
|
|
|
@ -5,13 +5,24 @@ from functools import partial
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import (accuracy_score, average_precision_score, auc, confusion_matrix, f1_score,
|
||||
fbeta_score, precision_score, recall_score, precision_recall_curve, roc_curve,
|
||||
roc_auc_score)
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
average_precision_score,
|
||||
auc,
|
||||
confusion_matrix,
|
||||
f1_score,
|
||||
fbeta_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
precision_recall_curve,
|
||||
roc_curve,
|
||||
roc_auc_score
|
||||
)
|
||||
|
||||
from pytorch_lightning.metrics.converters import _convert_to_numpy
|
||||
from pytorch_lightning.metrics.sklearn import (Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
|
||||
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
|
||||
from pytorch_lightning.metrics.sklearn import (
|
||||
Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
|
||||
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
|
||||
|
@ -25,37 +36,38 @@ def xy_only(func):
|
|||
@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
|
||||
pytest.param(Accuracy(), accuracy_score,
|
||||
{'y_pred': torch.randint(low=0, high=10, size=(128,)),
|
||||
'y_true': torch.randint(low=0, high=10, size=(128,))}, id='Accuracy'),
|
||||
'y_true': torch.randint(low=0, high=10, size=(128,))},
|
||||
id='Accuracy'),
|
||||
pytest.param(AUC(), auc, {'x': torch.arange(10, dtype=torch.float) / 10,
|
||||
'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2,
|
||||
0.2, 0.3, 0.5, 0.6, 0.7])}, id='AUC'),
|
||||
'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])},
|
||||
id='AUC'),
|
||||
pytest.param(AveragePrecision(), average_precision_score,
|
||||
{'y_score': torch.randint(2, size=(128,)),
|
||||
'y_true': torch.randint(2, size=(128,))}, id='AveragePrecision'),
|
||||
{'y_score': torch.randint(2, size=(128,)), 'y_true': torch.randint(2, size=(128,))},
|
||||
id='AveragePrecision'),
|
||||
pytest.param(ConfusionMatrix(), confusion_matrix,
|
||||
{'y_pred': torch.randint(10, size=(128,)),
|
||||
'y_true': torch.randint(10, size=(128,))}, id='ConfusionMatrix'),
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='ConfusionMatrix'),
|
||||
pytest.param(F1(average='macro'), partial(f1_score, average='macro'),
|
||||
{'y_pred': torch.randint(10, size=(128,)),
|
||||
'y_true': torch.randint(10, size=(128,))}, id='F1'),
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='F1'),
|
||||
pytest.param(FBeta(beta=0.5, average='macro'), partial(fbeta_score, beta=0.5, average='macro'),
|
||||
{'y_pred': torch.randint(10, size=(128,)),
|
||||
'y_true': torch.randint(10, size=(128,))}, id='FBeta'),
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='FBeta'),
|
||||
pytest.param(Precision(average='macro'), partial(precision_score, average='macro'),
|
||||
{'y_pred': torch.randint(10, size=(128,)),
|
||||
'y_true': torch.randint(10, size=(128,))}, id='Precision'),
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='Precision'),
|
||||
pytest.param(Recall(average='macro'), partial(recall_score, average='macro'),
|
||||
{'y_pred': torch.randint(10, size=(128,)),
|
||||
'y_true': torch.randint(10, size=(128,))}, id='Recall'),
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='Recall'),
|
||||
pytest.param(PrecisionRecallCurve(), xy_only(precision_recall_curve),
|
||||
{'probas_pred': torch.rand(size=(128,)),
|
||||
'y_true': torch.randint(2, size=(128,))}, id='PrecisionRecallCurve'),
|
||||
{'probas_pred': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
|
||||
id='PrecisionRecallCurve'),
|
||||
pytest.param(ROC(), xy_only(roc_curve),
|
||||
{'y_score': torch.rand(size=(128,)),
|
||||
'y_true': torch.randint(2, size=(128,))}, id='ROC'),
|
||||
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
|
||||
id='ROC'),
|
||||
pytest.param(AUROC(), roc_auc_score,
|
||||
{'y_score': torch.rand(size=(128,)),
|
||||
'y_true': torch.randint(2, size=(128,))}, id='AUROC'),
|
||||
{'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
|
||||
id='AUROC'),
|
||||
])
|
||||
def test_sklearn_metric(metric_class, sklearn_func, inputs: dict):
|
||||
numpy_inputs = apply_to_collection(
|
||||
|
|
Loading…
Reference in New Issue