Metrics docs (#2184)

* fixes

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* add workers fix

* Update docs/source/metrics.rst

Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>

* doctests

* fix docs

* fix doctests

* fix examples

* bug

* Update docs/source/metrics.rst

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
William Falcon 2020-06-16 07:42:56 -04:00 committed by GitHub
parent e289e45120
commit 55fbcc00f6
7 changed files with 696 additions and 44 deletions

@ -1,4 +1,318 @@
.. automodule:: pytorch_lightning.metrics
   :members:
   :noindex:
   :exclude-members:

.. testsetup:: *

    import torch
    import numpy as np
    from torch.nn import Module
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.metrics import TensorMetric, NumpyMetric

Metrics
=======

This is a general package for PyTorch metrics. These can also be used with regular, non-Lightning PyTorch code.
Metrics are used to monitor model performance.

This package provides two major pieces of functionality:

1. A ``Metric`` class you can use to implement metrics with built-in distributed (DDP) support; the resulting metrics are device agnostic.
2. A collection of popular metrics already implemented for you.

Example::

    from pytorch_lightning.metrics.functional import accuracy

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])

    # calculates accuracy across all GPUs and all nodes used in training
    accuracy(pred, target)

Out::

    tensor(0.7500)
--------------

Implement a metric
------------------

You can implement metrics as either a PyTorch metric or a numpy metric. Numpy metrics
will slow down training; use PyTorch metrics when possible.

Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

Use :class:`NumpyMetric` to implement numpy metrics. This class
handles automated DDP syncing, converts all inputs to numpy for the computation,
and converts all outputs back to tensors.

.. warning::
    Numpy metrics might slow down your training substantially,
    since every metric computation requires a GPU sync to convert tensors to numpy.
TensorMetric
^^^^^^^^^^^^

Here's an example showing how to implement a ``TensorMetric``:

.. testcode::

    class RMSE(TensorMetric):
        def forward(self, x, y):
            return torch.sqrt(torch.mean(torch.pow(x - y, 2.0)))
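Once defined, the metric is called like any other ``Module``. A minimal usage sketch (the sample values here are illustrative, not from the docs):

.. code-block:: python

    rmse = RMSE()

    pred = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.0, 2.0, 5.0])

    # the base class wraps forward() with DDP syncing and tensor conversion
    rmse(pred, target)  # tensor(1.1547) == sqrt(mean([0., 0., 4.]))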
.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric
   :noindex:
NumpyMetric
^^^^^^^^^^^

Here's an example showing how to implement a ``NumpyMetric``:

.. testcode::

    class RMSE(NumpyMetric):
        def forward(self, x, y):
            return np.sqrt(np.mean(np.power(x - y, 2.0)))
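Calling a ``NumpyMetric`` looks exactly the same from the outside; the numpy round-trip happens inside the base class. A sketch with the same illustrative values:

.. code-block:: python

    rmse = RMSE()

    # tensor inputs are converted to numpy before forward() runs ...
    score = rmse(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 5.0]))

    # ... and the numpy result is converted back to a tensor
    isinstance(score, torch.Tensor)  # True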
.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
   :noindex:
--------------

Class Metrics
-------------

The following metrics can be instantiated as part of a module definition (even with just
plain PyTorch).

.. testcode::

    from pytorch_lightning.metrics import Accuracy

    # Plain PyTorch
    class MyModule(Module):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def forward(self, x, y):
            y_hat = ...
            acc = self.metric(y_hat, y)

    # PyTorch Lightning
    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = ...
            acc = self.metric(y_hat, y)
These metrics even work when using distributed training:

.. code-block:: python

    model = MyModule()
    trainer = Trainer(gpus=8, num_nodes=2)

    # any metric automatically reduces across GPUs (even the ones you implement using Lightning)
    trainer.fit(model)
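The same pattern carries over to validation; a minimal sketch (the ``val_acc`` key and the elided model call are illustrative):

.. code-block:: python

    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = ...  # your model's predictions
            # under DDP, the metric value is already reduced across processes
            acc = self.metric(y_hat, y)
            return {'val_acc': acc}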
Accuracy
^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
   :noindex:

AveragePrecision
^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
   :noindex:

AUROC
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AUROC
   :noindex:

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
   :noindex:

DiceCoefficient
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
   :noindex:

F1
^^

.. autoclass:: pytorch_lightning.metrics.classification.F1
   :noindex:

FBeta
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
   :noindex:

PrecisionRecall
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecall
   :noindex:

Precision
^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Precision
   :noindex:

Recall
^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Recall
   :noindex:

ROC
^^^

.. autoclass:: pytorch_lightning.metrics.classification.ROC
   :noindex:

MulticlassROC
^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC
   :noindex:

MulticlassPrecisionRecall
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall
   :noindex:
--------------

Functional Metrics
------------------

accuracy (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.accuracy
   :noindex:

auc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auc
   :noindex:

auroc (F)
^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auroc
   :noindex:

average_precision (F)
^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.average_precision
   :noindex:

confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
   :noindex:

dice_score (F)
^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.dice_score
   :noindex:

f1_score (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.f1_score
   :noindex:

fbeta_score (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score
   :noindex:

multiclass_precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve
   :noindex:

multiclass_roc (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc
   :noindex:

precision (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision
   :noindex:

precision_recall (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
   :noindex:

precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
   :noindex:

recall (F)
^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.recall
   :noindex:

roc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.roc
   :noindex:

stat_scores (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
   :noindex:

stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
   :noindex:
----------------

Metric pre-processing
---------------------

to_categorical (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_categorical
   :noindex:

to_onehot (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
   :noindex:
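Both helpers convert between label representations before a metric is computed. A short sketch, reusing the values from the doctests further down this diff:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics.functional import to_onehot, to_categorical

    to_onehot(torch.tensor([1, 2, 3]))
    # tensor([[0, 1, 0, 0],
    #         [0, 0, 1, 0],
    #         [0, 0, 0, 1]])

    to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))
    # tensor([1, 0])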

@ -1,30 +1,15 @@
"""
Metrics
=======
Metrics are generally used to monitor model performance.
The following package aims to provide the most convenient ones as well
as a structure to implement your custom metrics for all the fancy research
you want to do.
For native PyTorch implementations of metrics, it is recommended to use
the :class:`TensorMetric` which handles automated DDP syncing and conversions
to tensors for all inputs and outputs.
If your metrics implementation works on numpy, just use the
:class:`NumpyMetric`, which handles the automated conversion of
inputs to and outputs from numpy as well as automated ddp syncing.
.. warning:: Employing numpy in your metric calculation might slow
down your training substantially, since every metric computation
requires a GPU sync to convert tensors to numpy.
"""
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.sklearn import (
SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
SklearnMetric,
Accuracy,
AveragePrecision,
AUC,
ConfusionMatrix,
F1,
FBeta,
Precision,
Recall,
PrecisionRecallCurve,
ROC,
AUROC)

@ -60,6 +60,14 @@ class Accuracy(TensorMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Accuracy()
>>> metric(pred, target)
tensor(0.7500)
"""
super().__init__(name='accuracy',
reduce_group=reduce_group,
@ -100,6 +108,17 @@ class ConfusionMatrix(TensorMetric):
normalize: whether to compute a normalized confusion matrix
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 2])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = ConfusionMatrix()
>>> metric(pred, target)
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 2.]])
"""
super().__init__(name='confusion_matrix',
reduce_group=reduce_group,
@ -138,6 +157,19 @@ class PrecisionRecall(TensorCollectionMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = PrecisionRecall()
>>> prec, recall, thr = metric(pred, target)
>>> prec
tensor([0.3333, 0.0000, 0.0000, 1.0000])
>>> recall
tensor([1., 0., 0., 0.])
>>> thr
tensor([1., 2., 3.])
"""
super().__init__(name='precision_recall_curve',
reduce_group=reduce_group,
@ -192,11 +224,18 @@ class Precision(TensorMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Precision()
>>> metric(pred, target)
tensor(1.)
"""
super().__init__(name='precision',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes
self.reduction = reduction
@ -239,6 +278,14 @@ class Recall(TensorMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Recall()
>>> metric(pred, target)
tensor(0.8333)
"""
super().__init__(name='recall',
reduce_group=reduce_group,
@ -281,6 +328,14 @@ class AveragePrecision(TensorMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = AveragePrecision()
>>> metric(pred, target)
tensor(0.3333)
"""
super().__init__(name='AP',
reduce_group=reduce_group,
@ -327,6 +382,14 @@ class AUROC(TensorMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = AUROC()
>>> metric(pred, target)
tensor(0.3333)
"""
super().__init__(name='auroc',
reduce_group=reduce_group,
@ -379,6 +442,14 @@ class FBeta(TensorMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = FBeta(0.25)
>>> metric(pred, target)
tensor(0.9815)
"""
super().__init__(name='fbeta',
reduce_group=reduce_group,
@ -425,6 +496,14 @@ class F1(TensorMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = F1()
>>> metric(pred, target)
tensor(0.8889)
"""
super().__init__(name='f1',
reduce_group=reduce_group,
@ -466,6 +545,19 @@ class ROC(TensorCollectionMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = ROC()
>>> fp, tp, thresholds = metric(pred, target)
>>> fp
tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
>>> tp
tensor([0., 0., 0., 1., 1.])
>>> thresholds
tensor([4., 3., 2., 1., 0.])
"""
super().__init__(name='roc',
reduce_group=reduce_group,
@ -519,6 +611,20 @@ class MulticlassROC(TensorCollectionMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = MulticlassROC()
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
super().__init__(name='multiclass_roc',
reduce_group=reduce_group,
@ -535,7 +641,7 @@ class MulticlassROC(TensorCollectionMetric):
Actual metric computation
Args:
pred: predicted labels
pred: predicted probability for each label
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
@ -569,6 +675,21 @@ class MulticlassPrecisionRecall(TensorCollectionMetric):
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = MulticlassPrecisionRecall()
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])),
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])))
"""
super().__init__(name='multiclass_precision_recall_curve',
reduce_group=reduce_group,
@ -586,7 +707,7 @@ class MulticlassPrecisionRecall(TensorCollectionMetric):
Actual metric computation
Args:
pred: predicted labels
pred: predicted probability for each label
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
@ -623,6 +744,20 @@ class DiceCoefficient(TensorMetric):
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = DiceCoefficient()
>>> metric(pred, target)
tensor(0.3333)
"""
super().__init__(name='dice',
reduce_group=reduce_group,
@ -638,7 +773,7 @@ class DiceCoefficient(TensorMetric):
Actual metric computation
Args:
pred: predicted labels
pred: predicted probability for each label
target: groundtruth labels
Return:

@ -0,0 +1,21 @@
from pytorch_lightning.metrics.functional.classification import (
accuracy,
auc,
auroc,
average_precision,
confusion_matrix,
dice_score,
f1_score,
fbeta_score,
multiclass_precision_recall_curve,
multiclass_roc,
precision,
precision_recall,
precision_recall_curve,
recall,
roc,
stat_scores,
stat_scores_multiple_classes,
to_categorical,
to_onehot
)
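With these flat re-exports, any functional metric can be imported in one line; a small sketch using the values from the ``accuracy`` doctest below:

>>> import torch
>>> from pytorch_lightning.metrics.functional import accuracy
>>> pred = torch.tensor([1, 2, 3])
>>> target = torch.tensor([0, 2, 3])
>>> accuracy(pred, target)  # two of the three predictions match
tensor(0.6667)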

@ -21,6 +21,15 @@ def to_onehot(
Output:
A sparse label tensor with shape [N, C, d1, d2, ...]
Example:
>>> x = torch.tensor([1, 2, 3])
>>> to_onehot(x)
tensor([[0, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]])
"""
if n_classes is None:
n_classes = int(tensor.max().detach().item() + 1)
@ -41,6 +50,13 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
Return:
A tensor with categorical labels [N, d2, ...]
Example:
>>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
>>> to_categorical(x)
tensor([1, 0])
"""
return torch.argmax(tensor, dim=argmax_dim)
@ -65,7 +81,8 @@ def get_num_classes(
if pred.ndim > target.ndim:
num_classes = pred.size(1)
else:
num_classes = int(target.max().detach().item() + 1)
num_target_classes = int(target.max().detach().item() + 1)
num_classes = num_target_classes
return num_classes
@ -88,6 +105,18 @@ def stat_scores(
Return:
Tensors in the following order: True Positive, False Positive, True Negative, False Negative, Support
Example:
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([0, 2, 3])
>>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1)
>>> tp, fp, tn, fn, sup # doctest: +NORMALIZE_WHITESPACE
(tensor(0),
tensor(1),
tensor(2),
tensor(0),
tensor(0))
"""
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)
@ -122,6 +151,17 @@ def stat_scores_multiple_classes(
Return:
Returns tensors for: tp, fp, tn, fn, sups
Example:
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([0, 2, 3])
>>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y)
>>> tps, fps, tns, fns, sups # doctest: +NORMALIZE_WHITESPACE
(tensor([0., 0., 1., 1.]),
tensor([0., 1., 0., 0.]),
tensor([2., 2., 2., 2.]),
tensor([1., 0., 0., 0.]),
tensor([1., 0., 1., 1.]))
"""
num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)
@ -135,9 +175,7 @@ def stat_scores_multiple_classes(
fns = torch.zeros((num_classes,), device=pred.device)
sups = torch.zeros((num_classes,), device=pred.device)
for c in range(num_classes):
tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred,
target=target,
class_index=c)
tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c)
return tps, fps, tns, fns, sups
@ -164,6 +202,14 @@ def accuracy(
Return:
A Tensor with the classification score.
Example:
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([0, 2, 3])
>>> accuracy(x, y)
tensor(0.6667)
"""
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target,
num_classes=num_classes)
@ -193,6 +239,16 @@ def confusion_matrix(
Return:
Tensor, confusion matrix C [num_classes, num_classes ]
Example:
>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([0, 2, 3])
>>> confusion_matrix(x, y)
tensor([[0., 1., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
"""
num_classes = get_num_classes(pred, target, None)
@ -229,10 +285,16 @@ def precision_recall(
Return:
Tensor with precision and recall
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> precision_recall(x, y)
(tensor(1.), tensor(0.8333))
"""
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred,
target=target,
num_classes=num_classes)
tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes)
tps = tps.to(torch.float)
fps = fps.to(torch.float)
@ -268,6 +330,14 @@ def precision(
Return:
Tensor with precision.
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> precision(x, y)
tensor(1.)
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[0]
@ -295,6 +365,13 @@ def recall(
Return:
Tensor with recall.
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> recall(x, y)
tensor(0.8333)
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[1]
@ -329,6 +406,13 @@ def fbeta_score(
Return:
Tensor with the value of F-score. It is a value between 0-1.
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> fbeta_score(x, y, 0.2)
tensor(0.9877)
"""
prec, rec = precision_recall(pred=pred, target=target,
num_classes=num_classes,
@ -363,6 +447,13 @@ def f1_score(
Return:
Tensor containing F1-score
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> f1_score(x, y)
tensor(0.8889)
"""
return fbeta_score(pred=pred, target=target, beta=1.,
num_classes=num_classes, reduction=reduction)
@ -431,6 +522,19 @@ def roc(
Return:
[Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> fpr, tpr, thresholds = roc(x, y)
>>> fpr
tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
>>> tpr
tensor([0., 0., 0., 1., 1.])
>>> thresholds
tensor([4, 3, 2, 1, 0])
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
@ -473,6 +577,19 @@ def multiclass_roc(
Return:
returns the ROC for each class: a (false-positive rate, true-positive rate, thresholds) tuple per class
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
num_classes = get_num_classes(pred, target, num_classes)
@ -503,6 +620,19 @@ def precision_recall_curve(
Return:
[Tensor, Tensor, Tensor]: precision, recall, thresholds
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> precision, recall, thresholds = precision_recall_curve(pred, target)
>>> precision
tensor([0.3333, 0.0000, 0.0000, 1.0000])
>>> recall
tensor([1., 0., 0., 0.])
>>> thresholds
tensor([1, 2, 3])
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
@ -547,7 +677,24 @@ def multiclass_precision_recall_curve(
num_classes: number of classes
Return:
returns the precision-recall curve for each class: a (precision, recall, thresholds) tuple per class
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> cls0, cls1, cls2, cls3 = multiclass_precision_recall_curve(pred, target)
>>> cls0  # (precision, recall, thresholds) for class 0
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
>>> cls1
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
>>> cls2
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
>>> cls3
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
"""
num_classes = get_num_classes(pred, target, num_classes)
@ -574,6 +721,13 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
Return:
AUC score (float)
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> auc(x, y)
tensor(4.)
"""
direction = 1.
@ -635,6 +789,13 @@ def auroc(
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> auroc(x, y)
tensor(0.3333)
"""
@auc_decorator(reorder=True)
@ -650,6 +811,21 @@ def average_precision(
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> torch.Tensor:
"""
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 2, 2])
>>> average_precision(x, y)
tensor(0.3333)
"""
precision, recall, _ = precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
@ -667,6 +843,26 @@ def dice_score(
no_fg_score: float = 0.0,
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
"""
Args:
pred: estimated probabilities
target: ground-truth labels
bg: whether to also compute dice for the background
nan_score: score to return, if a NaN occurs during computation
no_fg_score: score to return, if no foreground pixel was found in target
reduction: method to reduce the metric score over labels
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> dice_score(pred, target)
tensor(0.3333)
"""
n_classes = pred.shape[1]
bg = (1 - int(bool(bg)))
scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)

@ -351,5 +351,6 @@ def test_dice_score(pred, target, expected):
score = dice_score(torch.tensor(pred), torch.tensor(target))
assert score == expected
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py

@ -27,7 +27,7 @@ from tests.base import EvalModelTemplate
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
"""Tests use case where trainer saves the model, and user loads it from tags independently."""
# set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
monkeypatch.setenv('TORCH_HOME', tmpdir)
monkeypatch.setenv('TORCH_HOME', str(tmpdir))
model = EvalModelTemplate()