diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 6f70a3c73f..8d7322d470 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -1,4 +1,318 @@ -.. automodule:: pytorch_lightning.metrics - :members: - :noindex: - :exclude-members: +.. testsetup:: * + + from torch.nn import Module + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.metrics import TensorMetric, NumpyMetric + +Metrics +======= +This is a general package for PyTorch Metrics. These can also be used with regular non-lightning PyTorch code. +Metrics are used to monitor model performance. + +In this package we provide two major pieces of functionality. + +1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic. +2. A collection of popular metrics already implemented for you. + +Example:: + + from pytorch_lightning.metrics.functional import accuracy + + pred = torch.tensor([0, 1, 2, 3]) + target = torch.tensor([0, 1, 2, 2]) + + # calculates accuracy across all GPUs and all Nodes used in training + accuracy(pred, target) + +Out:: + + tensor(0.7500) + +-------------- + +Implement a metric +------------------ +You can implement metrics as either a PyTorch metric or a Numpy metric. Numpy metrics +will slow down training, use PyTorch metrics when possible. + +Use :class:`TensorMetric` to implement native PyTorch metrics. This class +handles automated DDP syncing and converts all inputs and outputs to tensors. + +Use :class:`NumpyMetric` to implement numpy metrics. This class +handles automated DDP syncing and converts all inputs and outputs to tensors. + +.. warning:: + Numpy metrics might slow down your training substantially, + since every metric computation requires a GPU sync to convert tensors to numpy. + +TensorMetric +^^^^^^^^^^^^ +Here's an example showing how to implement a TensorMetric + +.. testcode:: + + class RMSE(TensorMetric): + def forward(self, x, y): + return torch.sqrt(torch.mean(torch.pow(x-y, 2.0))) + +.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric + :noindex: + +NumpyMetric +^^^^^^^^^^^ +Here's an example showing how to implement a NumpyMetric + +.. testcode:: + + class RMSE(NumpyMetric): + def forward(self, x, y): + return np.sqrt(np.mean(np.power(x-y, 2.0))) + + +.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric + :noindex: + +-------------- + +Class Metrics +------------- +The following are metrics which can be instantiated as part of a module definition (even with just +plain PyTorch). + +.. testcode:: + + from pytorch_lightning.metrics import Accuracy + + # Plain PyTorch + class MyModule(Module): + def __init__(self): + super().__init__() + self.metric = Accuracy() + + def forward(self, x, y): + y_hat = ... + acc = self.metric(y_hat, y) + + # PyTorch Lightning + class MyModule(LightningModule): + def __init__(self): + super().__init__() + self.metric = Accuracy() + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = ... + acc = self.metric(y_hat, y) + +These metrics even work when using distributed training: + +.. code-block:: python + + model = MyModule() + trainer = Trainer(gpus=8, num_nodes=2) + + # any metric automatically reduces across GPUs (even the ones you implement using Lightning) + trainer.fit(model) + +Accuracy +^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.Accuracy + :noindex: + +AveragePrecision +^^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision + :noindex: + +AUROC +^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.AUROC + :noindex: + +ConfusionMatrix +^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix + :noindex: + +DiceCoefficient +^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient + :noindex: + +F1 +^^ + +.. autoclass:: pytorch_lightning.metrics.classification.F1 + :noindex: + +FBeta +^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.FBeta + :noindex: + +PrecisionRecall +^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecall + :noindex: + +Precision +^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.Precision + :noindex: + +Recall +^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.Recall + :noindex: + +ROC +^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.ROC + :noindex: + +MulticlassROC +^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC + :noindex: + +MulticlassPrecisionRecall +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecall + :noindex: + +-------------- + +Functional Metrics +------------------ + +accuracy (F) +^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.accuracy + :noindex: + +auc (F) +^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.auc + :noindex: + +auroc (F) +^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.auroc + :noindex: + +average_precision (F) +^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.average_precision + :noindex: + +confusion_matrix (F) +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix + :noindex: + +dice_score (F) +^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.dice_score + :noindex: + +f1_score (F) +^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.f1_score + :noindex: + +fbeta_score (F) +^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score + :noindex: + +multiclass_precision_recall_curve (F) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve + :noindex: + +multiclass_roc (F) +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc + :noindex: + +precision (F) +^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.precision + :noindex: + +precision_recall (F) +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.precision_recall + :noindex: + +precision_recall_curve (F) +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve + :noindex: + +recall (F) +^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.recall + :noindex: + +roc (F) +^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.roc + :noindex: + +stat_scores (F) +^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.stat_scores + :noindex: + +stat_scores_multiple_classes (F) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes + :noindex: + +---------------- + +Metric pre-processing +--------------------- +Metric + +to_categorical (F) +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.to_categorical + :noindex: + +to_onehot (F) +^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.functional.to_onehot + :noindex: diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 64ca41729d..ac026c3a74 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,30 +1,15 @@ -""" -Metrics -======= - -Metrics are generally used to monitor model performance. - -The following package aims to provide the most convenient ones as well -as a structure to implement your custom metrics for all the fancy research -you want to do. - -For native PyTorch implementations of metrics, it is recommended to use -the :class:`TensorMetric` which handles automated DDP syncing and conversions -to tensors for all inputs and outputs. - -If your metrics implementation works on numpy, just use the -:class:`NumpyMetric`, which handles the automated conversion of -inputs to and outputs from numpy as well as automated ddp syncing. - -.. warning:: Employing numpy in your metric calculation might slow - down your training substantially, since every metric computation - requires a GPU sync to convert tensors to numpy. - - -""" - from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric from pytorch_lightning.metrics.sklearn import ( - SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta, - Precision, Recall, PrecisionRecallCurve, ROC, AUROC) + SklearnMetric, + Accuracy, + AveragePrecision, + AUC, + ConfusionMatrix, + F1, + FBeta, + Precision, + Recall, + PrecisionRecallCurve, + ROC, + AUROC) diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index 3e02a8735b..db4318ed88 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -60,6 +60,14 @@ class Accuracy(TensorMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = Accuracy() + >>> metric(pred, target) + tensor(0.7500) + """ super().__init__(name='accuracy', reduce_group=reduce_group, @@ -100,6 +108,17 @@ class ConfusionMatrix(TensorMetric): normalize: whether to compute a normalized confusion matrix reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + + Example: + + >>> pred = torch.tensor([0, 1, 2, 2]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = ConfusionMatrix() + >>> metric(pred, target) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 2.]]) + """ super().__init__(name='confusion_matrix', reduce_group=reduce_group, @@ -138,6 +157,19 @@ class PrecisionRecall(TensorCollectionMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = PrecisionRecall() + >>> prec, recall, thr = metric(pred, target) + >>> prec + tensor([0.3333, 0.0000, 0.0000, 1.0000]) + >>> recall + tensor([1., 0., 0., 0.]) + >>> thr + tensor([1., 2., 3.]) + """ super().__init__(name='precision_recall_curve', reduce_group=reduce_group, @@ -192,11 +224,18 @@ class Precision(TensorMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = Precision() + >>> metric(pred, target) + tensor(1.) + """ super().__init__(name='precision', reduce_group=reduce_group, reduce_op=reduce_op) - self.num_classes = num_classes self.reduction = reduction @@ -239,6 +278,14 @@ class Recall(TensorMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = Recall() + >>> metric(pred, target) + tensor(0.8333) + """ super().__init__(name='recall', reduce_group=reduce_group, @@ -281,6 +328,14 @@ class AveragePrecision(TensorMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = AveragePrecision() + >>> metric(pred, target) + tensor(0.3333) + """ super().__init__(name='AP', reduce_group=reduce_group, @@ -327,6 +382,14 @@ class AUROC(TensorMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = AUROC() + >>> metric(pred, target) + tensor(0.3333) + """ super().__init__(name='auroc', reduce_group=reduce_group, @@ -379,6 +442,14 @@ class FBeta(TensorMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = FBeta(0.25) + >>> metric(pred, target) + tensor(0.9815) + """ super().__init__(name='fbeta', reduce_group=reduce_group, @@ -425,6 +496,14 @@ class F1(TensorMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = F1() + >>> metric(pred, target) + tensor(0.8889) + """ super().__init__(name='f1', reduce_group=reduce_group, @@ -466,6 +545,19 @@ class ROC(TensorCollectionMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> metric = ROC() + >>> fp, tp, thresholds = metric(pred, target) + >>> fp + tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]) + >>> tp + tensor([0., 0., 0., 1., 1.]) + >>> thresholds + tensor([4., 3., 2., 1., 0.]) + """ super().__init__(name='roc', reduce_group=reduce_group, @@ -519,6 +611,20 @@ class MulticlassROC(TensorCollectionMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassROC() + >>> classes_roc = metric(pred, target) + >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE + ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) """ super().__init__(name='multiclass_roc', reduce_group=reduce_group, @@ -535,7 +641,7 @@ class MulticlassROC(TensorCollectionMetric): Actual metric computation Args: - pred: predicted labels + pred: predicted probability for each label target: groundtruth labels sample_weight: Weights for each sample defining the sample's impact on the score @@ -569,6 +675,21 @@ class MulticlassPrecisionRecall(TensorCollectionMetric): reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassPrecisionRecall() + >>> classes_pr = metric(pred, target) + >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE + ((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), + (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])), + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))) + """ super().__init__(name='multiclass_precision_recall_curve', reduce_group=reduce_group, @@ -586,7 +707,7 @@ class MulticlassPrecisionRecall(TensorCollectionMetric): Actual metric computation Args: - pred: predicted labels + pred: predicted probability for each label target: groundtruth labels sample_weight: Weights for each sample defining the sample's impact on the score @@ -623,6 +744,20 @@ class DiceCoefficient(TensorMetric): - sum: add elements reduce_group: the process group to reduce metric results from DDP reduce_op: the operation to perform for ddp reduction + + Example: + + .. testcode: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = DiceCoefficient() + >>> classes_pr = metric(pred, target) + >>> metric(pred, target) + tensor(0.3333) """ super().__init__(name='dice', reduce_group=reduce_group, @@ -638,7 +773,7 @@ class DiceCoefficient(TensorMetric): Actual metric computation Args: - pred: predicted labels + pred: predicted probability for each label target: groundtruth labels Return: diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index e69de29bb2..2c8b8a85a9 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -0,0 +1,21 @@ +from pytorch_lightning.metrics.functional.classification import ( + accuracy, + auc, + auroc, + average_precision, + confusion_matrix, + dice_score, + f1_score, + fbeta_score, + multiclass_precision_recall_curve, + multiclass_roc, + precision, + precision_recall, + precision_recall_curve, + recall, + roc, + stat_scores, + stat_scores_multiple_classes, + to_categorical, + to_onehot +) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 64d821fbb3..ddb7746313 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -21,6 +21,15 @@ def to_onehot( Output: A sparse label tensor with shape [N, C, d1, d2, ...] + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> to_onehot(x) + tensor([[0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + """ if n_classes is None: n_classes = int(tensor.max().detach().item() + 1) @@ -41,6 +50,13 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: Return: A tensor with categorical labels [N, d2, ...] + + Example: + + >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + >>> to_categorical(x) + tensor([1, 0]) + """ return torch.argmax(tensor, dim=argmax_dim) @@ -65,7 +81,8 @@ def get_num_classes( if pred.ndim > target.ndim: num_classes = pred.size(1) else: - num_classes = int(target.max().detach().item() + 1) + num_target_classes = int(target.max().detach().item() + 1) + num_classes = num_target_classes return num_classes @@ -88,6 +105,18 @@ def stat_scores( Return: Tensors in the following order: True Positive, False Positive, True Negative, False Negative + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1) + >>> stat_scores(x, y, class_index=1) # doctest: +NORMALIZE_WHITESPACE + (tensor(0), + tensor(1), + tensor(2), + tensor(0), + tensor(0)) + """ if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) @@ -122,6 +151,17 @@ def stat_scores_multiple_classes( Return: Returns tensors for: tp, fp, tn, fn + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y) + >>> stat_scores_multiple_classes(x, y) # doctest: +NORMALIZE_WHITESPACE + (tensor([0., 0., 1., 1.]), + tensor([0., 1., 0., 0.]), + tensor([2., 2., 2., 2.]), + tensor([1., 0., 0., 0.]), + tensor([1., 0., 1., 1.])) """ num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) @@ -135,9 +175,7 @@ def stat_scores_multiple_classes( fns = torch.zeros((num_classes,), device=pred.device) sups = torch.zeros((num_classes,), device=pred.device) for c in range(num_classes): - tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, - target=target, - class_index=c) + tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c) return tps, fps, tns, fns, sups @@ -164,6 +202,14 @@ def accuracy( Return: A Tensor with the classification score. + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> accuracy(x, y) + tensor(0.6667) + """ tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) @@ -193,6 +239,16 @@ def confusion_matrix( Return: Tensor, confusion matrix C [num_classes, num_classes ] + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> confusion_matrix(x, y) + tensor([[0., 1., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]]) """ num_classes = get_num_classes(pred, target, None) @@ -229,10 +285,16 @@ def precision_recall( Return: Tensor with precision and recall + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> precision_recall(x, y) + (tensor(1.), tensor(0.8333)) + """ - tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, - target=target, - num_classes=num_classes) + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) tps = tps.to(torch.float) fps = fps.to(torch.float) @@ -268,6 +330,14 @@ def precision( Return: Tensor with precision. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> precision(x, y) + tensor(1.) + """ return precision_recall(pred=pred, target=target, num_classes=num_classes, reduction=reduction)[0] @@ -295,6 +365,13 @@ def recall( Return: Tensor with recall. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> recall(x, y) + tensor(0.8333) """ return precision_recall(pred=pred, target=target, num_classes=num_classes, reduction=reduction)[1] @@ -329,6 +406,13 @@ def fbeta_score( Return: Tensor with the value of F-score. It is a value between 0-1. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> fbeta_score(x, y, 0.2) + tensor(0.9877) """ prec, rec = precision_recall(pred=pred, target=target, num_classes=num_classes, @@ -363,6 +447,13 @@ def f1_score( Return: Tensor containing F1-score + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> f1_score(x, y) + tensor(0.8889) """ return fbeta_score(pred=pred, target=target, beta=1., num_classes=num_classes, reduction=reduction) @@ -431,6 +522,19 @@ def roc( Return: [Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> fpr, tpr, thresholds = roc(x,y) + >>> fpr + tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]) + >>> tpr + tensor([0., 0., 0., 1., 1.]) + >>> thresholds + tensor([4, 3, 2, 1, 0]) + """ fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, @@ -473,6 +577,19 @@ def multiclass_roc( Return: [num_classes, Tensor, Tensor, Tensor]: returns roc for each class. number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) """ num_classes = get_num_classes(pred, target, num_classes) @@ -503,6 +620,19 @@ def precision_recall_curve( Return: [Tensor, Tensor, Tensor]: precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target) + >>> precision + tensor([0.3333, 0.0000, 0.0000, 1.0000]) + >>> recall + tensor([1., 0., 0., 0.]) + >>> thresholds + tensor([1, 2, 3]) + """ fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, @@ -547,7 +677,24 @@ def multiclass_precision_recall_curve( num_classes: number of classes Return: - [num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds + [num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> nb_classes, precision, recall, thresholds = multiclass_precision_recall_curve(pred, target) + >>> nb_classes + (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) + >>> precision + (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])) + >>> recall + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) + >>> thresholds # doctest: +NORMALIZE_WHITESPACE + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])) """ num_classes = get_num_classes(pred, target, num_classes) @@ -574,6 +721,13 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True): Return: AUC score (float) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> auc(x, y) + tensor(4.) """ direction = 1. @@ -635,6 +789,13 @@ def auroc( target: ground-truth labels sample_weight: sample weights pos_label: the label for the positive class (default: 1.) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> auroc(x, y) + tensor(0.3333) """ @auc_decorator(reorder=True) @@ -650,6 +811,21 @@ def average_precision( sample_weight: Optional[Sequence] = None, pos_label: int = 1., ) -> torch.Tensor: + """ + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class (default: 1.) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> average_precision(x, y) + tensor(0.3333) + """ precision, recall, _ = precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -667,6 +843,26 @@ def dice_score( no_fg_score: float = 0.0, reduction: str = 'elementwise_mean', ) -> torch.Tensor: + """ + Args: + pred: estimated probabilities + target: ground-truth labels + bg: + nan_score: + no_fg_score: + reduction: + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision(pred, target) + tensor(0.2500) + + """ n_classes = pred.shape[1] bg = (1 - int(bool(bg))) scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index e4e6a5112e..e9bf9be1b5 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -351,5 +351,6 @@ def test_dice_score(pred, target, expected): score = dice_score(torch.tensor(pred), torch.tensor(target)) assert score == expected + # example data taken from # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index cd1996e2e9..424e4c7b94 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -27,7 +27,7 @@ from tests.base import EvalModelTemplate def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): """Tests use case where trainer saves the model, and user loads it from tags independently.""" # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir - monkeypatch.setenv('TORCH_HOME', tmpdir) + monkeypatch.setenv('TORCH_HOME', str(tmpdir)) model = EvalModelTemplate()