Metrics docs (#2184)
* fixes
* Apply suggestions from code review
* add workers fix
* Update docs/source/metrics.rst
* doctests
* fix docs
* fix doctests
* fix examples
* bug

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Nicki Skafte <nugginea@gmail.com>
parent e289e45120
commit 55fbcc00f6

docs/source/metrics.rst
@@ -1,4 +1,318 @@
- .. automodule:: pytorch_lightning.metrics
-    :members:
-    :noindex:
-    :exclude-members:

.. testsetup:: *

    import torch
    import numpy as np
    from torch.nn import Module
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.metrics import TensorMetric, NumpyMetric

Metrics
=======
This is a general package for PyTorch metrics. These can also be used with regular non-Lightning PyTorch code.
Metrics are used to monitor model performance.

In this package, we provide two major pieces of functionality:

1. A Metric class you can use to implement metrics with built-in distributed (ddp) support, which are device agnostic.
2. A collection of popular metrics already implemented for you.

Example::

    from pytorch_lightning.metrics.functional import accuracy

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])

    # calculates accuracy across all GPUs and all nodes used in training
    accuracy(pred, target)

Out::

    tensor(0.7500)

--------------

Implement a metric
------------------
You can implement metrics as either a PyTorch metric or a Numpy metric. Numpy metrics
can slow down training substantially; use PyTorch metrics when possible.

Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

Use :class:`NumpyMetric` to implement numpy metrics. This class
handles automated DDP syncing, converts all inputs to numpy arrays, and
converts all outputs back to tensors.

.. warning::
    Numpy metrics might slow down your training substantially,
    since every metric computation requires a GPU sync to convert tensors to numpy.

TensorMetric
^^^^^^^^^^^^
Here's an example showing how to implement a :class:`TensorMetric`:

.. testcode::

    class RMSE(TensorMetric):
        def forward(self, x, y):
            return torch.sqrt(torch.mean(torch.pow(x - y, 2.0)))
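
Once implemented, the metric behaves like any ``torch.nn.Module``: instantiate it and
call it on tensors. A minimal usage sketch for the ``RMSE`` metric above (the input
values are illustrative):

.. code-block:: python

    metric = RMSE()
    pred = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.0, 2.0, 4.0])
    metric(pred, target)  # tensor(0.5774), i.e. sqrt(mean([0., 0., 1.]))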

.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric
    :noindex:

NumpyMetric
^^^^^^^^^^^
Here's an example showing how to implement a :class:`NumpyMetric`:

.. testcode::

    class RMSE(NumpyMetric):
        def forward(self, x, y):
            return np.sqrt(np.mean(np.power(x - y, 2.0)))
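
Even though ``forward`` is written against numpy here, the metric is still called with
tensors: :class:`NumpyMetric` converts the inputs to numpy arrays and wraps the numpy
result back into a tensor. A minimal sketch (illustrative values):

.. code-block:: python

    metric = RMSE()
    metric(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 4.0]))  # tensor(0.5774)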

.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
    :noindex:

--------------

Class Metrics
-------------
The following metrics can be instantiated as part of a module definition (even with
plain PyTorch).

.. testcode::

    from pytorch_lightning.metrics import Accuracy

    # Plain PyTorch
    class MyModule(Module):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def forward(self, x, y):
            y_hat = ...
            acc = self.metric(y_hat, y)

    # PyTorch Lightning
    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = ...
            acc = self.metric(y_hat, y)

These metrics even work when using distributed training:

.. code-block:: python

    model = MyModule()
    trainer = Trainer(gpus=8, num_nodes=2)

    # any metric automatically reduces across GPUs (even the ones you implement using Lightning)
    trainer.fit(model)

Accuracy
^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
    :noindex:

AveragePrecision
^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
    :noindex:

AUROC
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AUROC
    :noindex:

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
    :noindex:

DiceCoefficient
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
    :noindex:

F1
^^

.. autoclass:: pytorch_lightning.metrics.classification.F1
    :noindex:

FBeta
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
    :noindex:

PrecisionRecall
^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecall
    :noindex:

Precision
^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Precision
    :noindex:

Recall
^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.Recall
    :noindex:

ROC
^^^

.. autoclass:: pytorch_lightning.metrics.classification.ROC
    :noindex:

MulticlassROC
^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC
    :noindex:

MulticlassPrecisionRecall
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall
    :noindex:

--------------

Functional Metrics
------------------
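
Each functional metric is a stateless function that takes tensors and returns a tensor
(or a tuple of tensors). A short sketch, using the same toy predictions as above
(outputs taken from the doctests of these functions):

.. code-block:: python

    from pytorch_lightning.metrics.functional import accuracy, precision_recall

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])

    accuracy(pred, target)          # tensor(0.7500)
    precision_recall(pred, target)  # (tensor(1.), tensor(0.8333))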

accuracy (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.accuracy
    :noindex:

auc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auc
    :noindex:

auroc (F)
^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.auroc
    :noindex:

average_precision (F)
^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.average_precision
    :noindex:

confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
    :noindex:

dice_score (F)
^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.dice_score
    :noindex:

f1_score (F)
^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.f1_score
    :noindex:

fbeta_score (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score
    :noindex:

multiclass_precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve
    :noindex:

multiclass_roc (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc
    :noindex:

precision (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision
    :noindex:

precision_recall (F)
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
    :noindex:

precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
    :noindex:

recall (F)
^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.recall
    :noindex:

roc (F)
^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.roc
    :noindex:

stat_scores (F)
^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
    :noindex:

stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
    :noindex:

----------------

Metric pre-processing
---------------------
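
These helpers convert between label formats before a metric is computed. A minimal
sketch (outputs taken from the doctests of these functions):

.. code-block:: python

    from pytorch_lightning.metrics.functional import to_categorical, to_onehot

    to_onehot(torch.tensor([1, 2, 3]))
    # tensor([[0, 1, 0, 0],
    #         [0, 0, 1, 0],
    #         [0, 0, 0, 1]])

    to_categorical(torch.tensor([[0.2, 0.5], [0.9, 0.1]]))
    # tensor([1, 0])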

to_categorical (F)
^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_categorical
    :noindex:

to_onehot (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
    :noindex:
@@ -1,30 +1,15 @@
- """
- Metrics
- =======
-
- Metrics are generally used to monitor model performance.
-
- The following package aims to provide the most convenient ones as well
- as a structure to implement your custom metrics for all the fancy research
- you want to do.
-
- For native PyTorch implementations of metrics, it is recommended to use
- the :class:`TensorMetric` which handles automated DDP syncing and conversions
- to tensors for all inputs and outputs.
-
- If your metrics implementation works on numpy, just use the
- :class:`NumpyMetric`, which handles the automated conversion of
- inputs to and outputs from numpy as well as automated ddp syncing.
-
- .. warning:: Employing numpy in your metric calculation might slow
-     down your training substantially, since every metric computation
-     requires a GPU sync to convert tensors to numpy.
-
- """
-
  from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
  from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
  from pytorch_lightning.metrics.sklearn import (
-     SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta,
-     Precision, Recall, PrecisionRecallCurve, ROC, AUROC)
+     SklearnMetric,
+     Accuracy,
+     AveragePrecision,
+     AUC,
+     ConfusionMatrix,
+     F1,
+     FBeta,
+     Precision,
+     Recall,
+     PrecisionRecallCurve,
+     ROC,
+     AUROC)
@@ -60,6 +60,14 @@ class Accuracy(TensorMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = Accuracy()
            >>> metric(pred, target)
            tensor(0.7500)

        """
        super().__init__(name='accuracy',
                         reduce_group=reduce_group,

@@ -100,6 +108,17 @@ class ConfusionMatrix(TensorMetric):
            normalize: whether to compute a normalized confusion matrix
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 2])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = ConfusionMatrix()
            >>> metric(pred, target)
            tensor([[1., 0., 0.],
                    [0., 1., 0.],
                    [0., 0., 2.]])

        """
        super().__init__(name='confusion_matrix',
                         reduce_group=reduce_group,

@@ -138,6 +157,19 @@ class PrecisionRecall(TensorCollectionMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = PrecisionRecall()
            >>> prec, recall, thr = metric(pred, target)
            >>> prec
            tensor([0.3333, 0.0000, 0.0000, 1.0000])
            >>> recall
            tensor([1., 0., 0., 0.])
            >>> thr
            tensor([1., 2., 3.])

        """
        super().__init__(name='precision_recall_curve',
                         reduce_group=reduce_group,

@@ -192,11 +224,18 @@ class Precision(TensorMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = Precision()
            >>> metric(pred, target)
            tensor(1.)

        """
        super().__init__(name='precision',
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)

        self.num_classes = num_classes
        self.reduction = reduction

@@ -239,6 +278,14 @@ class Recall(TensorMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = Recall()
            >>> metric(pred, target)
            tensor(0.8333)

        """
        super().__init__(name='recall',
                         reduce_group=reduce_group,

@@ -281,6 +328,14 @@ class AveragePrecision(TensorMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = AveragePrecision()
            >>> metric(pred, target)
            tensor(0.3333)

        """
        super().__init__(name='AP',
                         reduce_group=reduce_group,

@@ -327,6 +382,14 @@ class AUROC(TensorMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = AUROC()
            >>> metric(pred, target)
            tensor(0.3333)

        """
        super().__init__(name='auroc',
                         reduce_group=reduce_group,

@@ -379,6 +442,14 @@ class FBeta(TensorMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = FBeta(0.25)
            >>> metric(pred, target)
            tensor(0.9815)

        """
        super().__init__(name='fbeta',
                         reduce_group=reduce_group,

@@ -425,6 +496,14 @@ class F1(TensorMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = F1()
            >>> metric(pred, target)
            tensor(0.8889)

        """
        super().__init__(name='f1',
                         reduce_group=reduce_group,

@@ -466,6 +545,19 @@ class ROC(TensorCollectionMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([0, 1, 2, 3])
            >>> target = torch.tensor([0, 1, 2, 2])
            >>> metric = ROC()
            >>> fp, tp, thresholds = metric(pred, target)
            >>> fp
            tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
            >>> tp
            tensor([0., 0., 0., 1., 1.])
            >>> thresholds
            tensor([4., 3., 2., 1., 0.])

        """
        super().__init__(name='roc',
                         reduce_group=reduce_group,

@@ -519,6 +611,20 @@ class MulticlassROC(TensorCollectionMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
            ...                      [0.05, 0.85, 0.05, 0.05],
            ...                      [0.05, 0.05, 0.85, 0.05],
            ...                      [0.05, 0.05, 0.05, 0.85]])
            >>> target = torch.tensor([0, 1, 3, 2])
            >>> metric = MulticlassROC()
            >>> classes_roc = metric(pred, target)
            >>> metric(pred, target)   # doctest: +NORMALIZE_WHITESPACE
            ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
             (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
             (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
             (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
        """
        super().__init__(name='multiclass_roc',
                         reduce_group=reduce_group,
@@ -535,7 +641,7 @@ class MulticlassROC(TensorCollectionMetric):
        Actual metric computation

        Args:
-           pred: predicted labels
+           pred: predicted probability for each label
            target: groundtruth labels
            sample_weight: Weights for each sample defining the sample's impact on the score

@@ -569,6 +675,21 @@ class MulticlassPrecisionRecall(TensorCollectionMetric):
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
            ...                      [0.05, 0.85, 0.05, 0.05],
            ...                      [0.05, 0.05, 0.85, 0.05],
            ...                      [0.05, 0.05, 0.05, 0.85]])
            >>> target = torch.tensor([0, 1, 3, 2])
            >>> metric = MulticlassPrecisionRecall()
            >>> classes_pr = metric(pred, target)
            >>> metric(pred, target)   # doctest: +NORMALIZE_WHITESPACE
            ((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
             (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
             (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])),
             (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])))

        """
        super().__init__(name='multiclass_precision_recall_curve',
                         reduce_group=reduce_group,

@@ -586,7 +707,7 @@ class MulticlassPrecisionRecall(TensorCollectionMetric):
        Actual metric computation

        Args:
-           pred: predicted labels
+           pred: predicted probability for each label
            target: groundtruth labels
            sample_weight: Weights for each sample defining the sample's impact on the score

@@ -623,6 +744,20 @@ class DiceCoefficient(TensorMetric):
                - sum: add elements
            reduce_group: the process group to reduce metric results from DDP
            reduce_op: the operation to perform for ddp reduction

        Example:

            >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
            ...                      [0.05, 0.85, 0.05, 0.05],
            ...                      [0.05, 0.05, 0.85, 0.05],
            ...                      [0.05, 0.05, 0.05, 0.85]])
            >>> target = torch.tensor([0, 1, 3, 2])
            >>> metric = DiceCoefficient()
            >>> classes_pr = metric(pred, target)
            >>> metric(pred, target)
            tensor(0.3333)
        """
        super().__init__(name='dice',
                         reduce_group=reduce_group,

@@ -638,7 +773,7 @@ class DiceCoefficient(TensorMetric):
        Actual metric computation

        Args:
-           pred: predicted labels
+           pred: predicted probability for each label
            target: groundtruth labels

        Return:
@@ -0,0 +1,21 @@
from pytorch_lightning.metrics.functional.classification import (
    accuracy,
    auc,
    auroc,
    average_precision,
    confusion_matrix,
    dice_score,
    f1_score,
    fbeta_score,
    multiclass_precision_recall_curve,
    multiclass_roc,
    precision,
    precision_recall,
    precision_recall_curve,
    recall,
    roc,
    stat_scores,
    stat_scores_multiple_classes,
    to_categorical,
    to_onehot
)
@@ -21,6 +21,15 @@ def to_onehot(

    Output:
        A sparse label tensor with shape [N, C, d1, d2, ...]

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> to_onehot(x)
        tensor([[0, 1, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1]])

    """
    if n_classes is None:
        n_classes = int(tensor.max().detach().item() + 1)

@@ -41,6 +50,13 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:

    Return:
        A tensor with categorical labels [N, d2, ...]

    Example:

        >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
        >>> to_categorical(x)
        tensor([1, 0])

    """
    return torch.argmax(tensor, dim=argmax_dim)

@@ -65,7 +81,8 @@ def get_num_classes(
    if pred.ndim > target.ndim:
        num_classes = pred.size(1)
    else:
-       num_classes = int(target.max().detach().item() + 1)
+       num_target_classes = int(target.max().detach().item() + 1)
+       num_classes = num_target_classes
    return num_classes

@@ -88,6 +105,18 @@ def stat_scores(
    Return:
        Tensors in the following order: True Positive, False Positive, True Negative, False Negative, Support

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([0, 2, 3])
        >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1)
        >>> stat_scores(x, y, class_index=1)   # doctest: +NORMALIZE_WHITESPACE
        (tensor(0),
         tensor(1),
         tensor(2),
         tensor(0),
         tensor(0))

    """
    if pred.ndim == target.ndim + 1:
        pred = to_categorical(pred, argmax_dim=argmax_dim)

@@ -122,6 +151,17 @@ def stat_scores_multiple_classes(
    Return:
        Returns tensors for: tp, fp, tn, fn, sup

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([0, 2, 3])
        >>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y)
        >>> stat_scores_multiple_classes(x, y)   # doctest: +NORMALIZE_WHITESPACE
        (tensor([0., 0., 1., 1.]),
         tensor([0., 1., 0., 0.]),
         tensor([2., 2., 2., 2.]),
         tensor([1., 0., 0., 0.]),
         tensor([1., 0., 1., 1.]))
    """
    num_classes = get_num_classes(pred=pred, target=target,
                                  num_classes=num_classes)

@@ -135,9 +175,7 @@ def stat_scores_multiple_classes(
    fns = torch.zeros((num_classes,), device=pred.device)
    sups = torch.zeros((num_classes,), device=pred.device)
    for c in range(num_classes):
-       tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred,
-                                                             target=target,
-                                                             class_index=c)
+       tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c)

    return tps, fps, tns, fns, sups
@@ -164,6 +202,14 @@ def accuracy(

    Return:
        A Tensor with the classification score.

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([0, 2, 3])
        >>> accuracy(x, y)
        tensor(0.6667)

    """
    tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target,
                                                            num_classes=num_classes)

@@ -193,6 +239,16 @@ def confusion_matrix(

    Return:
        Tensor, confusion matrix C [num_classes, num_classes]

    Example:

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([0, 2, 3])
        >>> confusion_matrix(x, y)
        tensor([[0., 1., 0., 0.],
                [0., 0., 0., 0.],
                [0., 0., 1., 0.],
                [0., 0., 0., 1.]])
    """
    num_classes = get_num_classes(pred, target, None)

@@ -229,10 +285,16 @@ def precision_recall(

    Return:
        Tensor with precision and recall

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> precision_recall(x, y)
        (tensor(1.), tensor(0.8333))

    """
-   tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred,
-                                                            target=target,
-                                                            num_classes=num_classes)
+   tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes)

    tps = tps.to(torch.float)
    fps = fps.to(torch.float)

@@ -268,6 +330,14 @@ def precision(

    Return:
        Tensor with precision.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> precision(x, y)
        tensor(1.)

    """
    return precision_recall(pred=pred, target=target,
                            num_classes=num_classes, reduction=reduction)[0]

@@ -295,6 +365,13 @@ def recall(

    Return:
        Tensor with recall.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> recall(x, y)
        tensor(0.8333)
    """
    return precision_recall(pred=pred, target=target,
                            num_classes=num_classes, reduction=reduction)[1]

@@ -329,6 +406,13 @@ def fbeta_score(

    Return:
        Tensor with the value of F-score. It is a value between 0 and 1.

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> fbeta_score(x, y, 0.2)
        tensor(0.9877)
    """
    prec, rec = precision_recall(pred=pred, target=target,
                                 num_classes=num_classes,

@@ -363,6 +447,13 @@ def f1_score(

    Return:
        Tensor containing F1-score

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> f1_score(x, y)
        tensor(0.8889)
    """
    return fbeta_score(pred=pred, target=target, beta=1.,
                       num_classes=num_classes, reduction=reduction)
@@ -431,6 +522,19 @@ def roc(

    Return:
        [Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> fpr, tpr, thresholds = roc(x, y)
        >>> fpr
        tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
        >>> tpr
        tensor([0., 0., 0., 1., 1.])
        >>> thresholds
        tensor([4, 3, 2, 1, 0])

    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,

@@ -473,6 +577,19 @@ def multiclass_roc(
    Return:
        [num_classes, Tensor, Tensor, Tensor]: returns ROC for each class.
        number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> multiclass_roc(pred, target)   # doctest: +NORMALIZE_WHITESPACE
        ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
         (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
    """
    num_classes = get_num_classes(pred, target, num_classes)

@@ -503,6 +620,19 @@ def precision_recall_curve(

    Return:
        [Tensor, Tensor, Tensor]: precision, recall, thresholds

    Example:

        >>> pred = torch.tensor([0, 1, 2, 3])
        >>> target = torch.tensor([0, 1, 2, 2])
        >>> precision, recall, thresholds = precision_recall_curve(pred, target)
        >>> precision
        tensor([0.3333, 0.0000, 0.0000, 1.0000])
        >>> recall
        tensor([1., 0., 0., 0.])
        >>> thresholds
        tensor([1, 2, 3])

    """
    fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                             sample_weight=sample_weight,

@@ -547,7 +677,24 @@ def multiclass_precision_recall_curve(
        num_classes: number of classes

    Return:
        [num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> cls0, cls1, cls2, cls3 = multiclass_precision_recall_curve(pred, target)
        >>> cls0   # (precision, recall, thresholds) for class 0
        (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
        >>> cls1
        (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500]))
        >>> cls2
        (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
        >>> cls3
        (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))
    """
    num_classes = get_num_classes(pred, target, num_classes)
@@ -574,6 +721,13 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):

    Return:
        AUC score (float)

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> auc(x, y)
        tensor(4.)
    """
    direction = 1.

@@ -635,6 +789,13 @@ def auroc(
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1)

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> auroc(x, y)
        tensor(0.3333)
    """

    @auc_decorator(reorder=True)

@@ -650,6 +811,21 @@ def average_precision(
        sample_weight: Optional[Sequence] = None,
        pos_label: int = 1,
) -> torch.Tensor:
    """
    Args:
        pred: estimated probabilities
        target: ground-truth labels
        sample_weight: sample weights
        pos_label: the label for the positive class (default: 1)

    Example:

        >>> x = torch.tensor([0, 1, 2, 3])
        >>> y = torch.tensor([0, 1, 2, 2])
        >>> average_precision(x, y)
        tensor(0.3333)
    """
    precision, recall, _ = precision_recall_curve(pred=pred, target=target,
                                                  sample_weight=sample_weight,
                                                  pos_label=pos_label)

@@ -667,6 +843,26 @@ def dice_score(
        no_fg_score: float = 0.0,
        reduction: str = 'elementwise_mean',
) -> torch.Tensor:
    """
    Args:
        pred: estimated probabilities
        target: ground-truth labels
        bg: whether to also compute dice for the background
        nan_score: score to return if a NaN occurs during computation
        no_fg_score: score to return if no foreground pixel was found in target
        reduction: a method for reducing the scores over labels (default: takes the mean)

    Example:

        >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
        ...                      [0.05, 0.85, 0.05, 0.05],
        ...                      [0.05, 0.05, 0.85, 0.05],
        ...                      [0.05, 0.05, 0.05, 0.85]])
        >>> target = torch.tensor([0, 1, 3, 2])
        >>> dice_score(pred, target)
        tensor(0.3333)

    """
    n_classes = pred.shape[1]
    bg = (1 - int(bool(bg)))
    scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)
@@ -351,5 +351,6 @@ def test_dice_score(pred, target, expected):
    score = dice_score(torch.tensor(pred), torch.tensor(target))
    assert score == expected


# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py

@@ -27,7 +27,7 @@ from tests.base import EvalModelTemplate
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
-   monkeypatch.setenv('TORCH_HOME', tmpdir)
+   monkeypatch.setenv('TORCH_HOME', str(tmpdir))

    model = EvalModelTemplate()