From e3732789d7af95d6838ba861e67022afed67379b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 5 Aug 2020 11:32:53 +0200 Subject: [PATCH] Add remaining sklearn metrics (#2562) * added balanced accuracy * added dcg score * added mean absolute error * added mean squared error * fix * added mean squared log error * add median absolute error and r2 score * switch arguments * added mean poisson deviance * add mean gamma deviance and mean tweedie deviance * fix styling * added explained variance score * added cohen kappa score * added hamming, hinge, jaccard * fix styling * update sklearn requirement to newer version * update requirement * fix doctest * fix tests * added balanced accuracy * added dcg score * added mean absolute error * added mean squared error * fix * added mean squared log error * add median absolute error and r2 score * switch arguments * added mean poisson deviance * add mean gamma deviance and mean tweedie deviance * fix styling * added explained variance score * added cohen kappa score * added hamming, hinge, jaccard * fix styling * update sklearn requirement to newer version * fix doctest * fix tests * fix doctest * fix failing docs * fix test * trying to fix errors * Apply suggestions from code review * format Co-authored-by: Nicki Skafte Co-authored-by: Jirka Borovec Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 + docs/source/metrics.rst | 90 +++ pytorch_lightning/metrics/sklearns.py | 908 +++++++++++++++++++++++++- requirements/extra.txt | 4 +- tests/metrics/test_sklearn.py | 79 ++- 5 files changed, 1070 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2800dfc792..a56caac77a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added support for Mean in DDP Sync ([#2568](https://github.com/PyTorchLightning/pytorch-lightning/pull/2568)) +- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562)) + ### Changed - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594)) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index a32362af23..7102ac059e 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -484,6 +484,17 @@ AveragePrecision (sk) .. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision :noindex: +BalancedAccuracy (sk) +^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.BalancedAccuracy + :noindex: + +CohenKappaScore (sk) +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.CohenKappaScore + :noindex: ConfusionMatrix (sk) ^^^^^^^^^^^^^^^^^^^^ @@ -491,6 +502,12 @@ ConfusionMatrix (sk) .. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix :noindex: +DCG (sk) +^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.DCG + :noindex: + F1 (sk) ^^^^^^^ @@ -503,6 +520,24 @@ FBeta (sk) .. autofunction:: pytorch_lightning.metrics.sklearns.FBeta :noindex: +Hamming (sk) +^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.Hamming + :noindex: + +Hinge (sk) +^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.Hinge + :noindex: + +Jaccard (sk) +^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.Jaccard + :noindex: + Precision (sk) ^^^^^^^^^^^^^^ @@ -532,3 +567,58 @@ AUROC (sk) .. 
autofunction:: pytorch_lightning.metrics.sklearns.AUROC :noindex: + +ExplainedVariance (sk) +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.ExplainedVariance + :noindex: + +MeanAbsoluteError (sk) +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.MeanAbsoluteError + :noindex: + +MeanSquaredError (sk) +^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredError + :noindex: + +MeanSquaredLogError (sk) +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredLogError + :noindex: + +MedianAbsoluteError (sk) +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.MedianAbsoluteError + :noindex: + +R2Score (sk) +^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.R2Score + :noindex: + +MeanPoissonDeviance (sk) +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.MeanPoissonDeviance + :noindex: + +MeanGammaDeviance (sk) +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: pytorch_lightning.metrics.sklearns.MeanGammaDeviance + :noindex: + +MeanTweedieDeviance (sk) +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance + :noindex: + diff --git a/pytorch_lightning/metrics/sklearns.py b/pytorch_lightning/metrics/sklearns.py index 004649c293..df8f826f69 100644 --- a/pytorch_lightning/metrics/sklearns.py +++ b/pytorch_lightning/metrics/sklearns.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union, Sequence +from typing import Any, Optional, Union, Sequence, List import numpy as np import torch @@ -29,6 +29,7 @@ class SklearnMetric(NumpyMetric): Note: The order of targets and predictions may be different from the order typically used in PyTorch """ + def __init__( self, metric_name: str, @@ -91,6 +92,7 @@ class Accuracy(SklearnMetric): tensor([0.7500]) """ + def __init__( self, normalize: bool = True, @@ -115,7 +117,7 @@ class Accuracy(SklearnMetric): self, y_pred: np.ndarray, y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Computes the accuracy @@ -146,6 +148,7 @@ class AUC(SklearnMetric): >>> metric(y_pred, y_true) tensor([4.]) """ + def __init__( self, reduce_group: Any = group.WORLD, @@ -183,6 +186,7 @@ class AveragePrecision(SklearnMetric): Calculates the average precision (AP) score. 
""" + def __init__( self, average: Optional[str] = 'macro', @@ -216,7 +220,7 @@ class AveragePrecision(SklearnMetric): self, y_score: np.ndarray, y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None + sample_weight: Optional[np.ndarray] = None, ) -> float: """ Args: @@ -232,6 +236,122 @@ class AveragePrecision(SklearnMetric): sample_weight=sample_weight) +class BalancedAccuracy(SklearnMetric): + """ Compute the balanced accuracy score + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([0, 0, 0, 1]) + >>> y_true = torch.tensor([0, 0, 1, 1]) + >>> metric = BalancedAccuracy() + >>> metric(y_pred, y_true) + tensor([0.7500]) + + """ + + def __init__( + self, + adjusted: bool = False, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + adjusted: If ``True``, the result sis adjusted for chance, such that random performance + corresponds to 0 and perfect performance corresponds to 1 + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('balanced_accuracy_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + adjusted=adjusted) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ) -> float: + """ + Args: + y_pred: the array containing the predictions (already in categorical form) + y_true: the array containing the targets (in categorical form) + sample_weight: Sample weights. 
+ + Return: + balanced accuracy score + + """ + return super().forward(y_true=y_true, + y_pred=y_pred, + sample_weight=sample_weight) + + +class CohenKappaScore(SklearnMetric): + """ + Calculates Cohen's kappa: a statistic that measures inter-annotator agreement + + Example: + + >>> y_pred = torch.tensor([1, 2, 0, 2]) + >>> y_true = torch.tensor([2, 2, 2, 1]) + >>> metric = CohenKappaScore() + >>> metric(y_pred, y_true) + tensor([-0.3333]) + + """ + + def __init__( + self, + labels: Optional[Sequence] = None, + weights: Optional[str] = None, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + labels: List of labels to index the matrix. This may be used to reorder + or select a subset of labels. + If none is given, those that appear at least once + in ``y1`` or ``y2`` are used in sorted order. + weights: string indicating weighting type used in scoring. None + means no weighting, string ``linear`` means linear weighted + and ``quadratic`` means quadratic weighted + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('cohen_kappa_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels, + weights=weights) + + def forward( + self, + y1: np.ndarray, + y2: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ) -> float: + """ + Args: + y1: Labels assigned by first annotator + y2: Labels assigned by second annotator + sample_weight: Sample weights. 
+ + Return: + Cohen's kappa score + """ + return super().forward(y1=y1, y2=y2, sample_weight=sample_weight) + + class ConfusionMatrix(SklearnMetric): """ Compute confusion matrix to evaluate the accuracy of a classification @@ -250,8 +370,10 @@ class ConfusionMatrix(SklearnMetric): [0., 1., 1.]]) """ + def __init__( - self, labels: Optional[Sequence] = None, + self, + labels: Optional[Sequence] = None, reduce_group: Any = group.WORLD, reduce_op: Any = ReduceOp.SUM, ): @@ -284,6 +406,68 @@ class ConfusionMatrix(SklearnMetric): return super().forward(y_pred=y_pred, y_true=y_true) +class DCG(SklearnMetric): + """ Compute discounted cumulative gain + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_score = torch.tensor([[.1, .2, .3, 4, 70]]) + >>> y_true = torch.tensor([[10, 0, 0, 1, 5]]) + >>> metric = DCG() + >>> metric(y_score, y_true) + tensor([9.4995]) + """ + + def __init__( + self, + k: Optional[int] = None, + log_base: float = 2, + ignore_ties: bool = False, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + k: only consider the highest k scores in the ranking + log_base: base of the logarithm used for the discount + ignore_ties: If ``True``, assume there are no ties in y_score for efficiency gains + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. 
+ """ + super().__init__('dcg_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + k=k, + log_base=log_base, + ignore_ties=ignore_ties) + + def forward( + self, + y_score: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ) -> float: + """ + Args: + y_score: target scores, either probability estimates, confidence values + or non-thresholded measure of decisions + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + DCG score + + """ + return super().forward(y_true=y_true, + y_score=y_score, + sample_weight=sample_weight) + + class F1(SklearnMetric): r""" Compute the F1 score, also known as balanced F-score or F-measure @@ -313,7 +497,8 @@ class F1(SklearnMetric): """ def __init__( - self, labels: Optional[Sequence] = None, + self, + labels: Optional[Sequence] = None, pos_label: Union[str, int] = 1, average: Optional[str] = 'macro', reduce_group: Any = group.WORLD, @@ -365,7 +550,7 @@ class F1(SklearnMetric): self, y_pred: np.ndarray, y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -460,7 +645,7 @@ class FBeta(SklearnMetric): self, y_pred: np.ndarray, y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -476,6 +661,193 @@ class FBeta(SklearnMetric): return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) +class Hamming(SklearnMetric): + """ + Computes the average hamming loss + + Example: + + >>> y_pred = torch.tensor([0, 1, 2, 3]) + >>> y_true = torch.tensor([1, 1, 2, 3]) + >>> metric = Hamming() + >>> metric(y_pred, y_true) + tensor([0.2500]) + + """ + + def __init__( + self, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + reduce_group: the process group for DDP reduces (only needed for DDP training). 
+ Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + + """ + super().__init__('hamming_loss', + reduce_group=reduce_group, + reduce_op=reduce_op) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ) -> Union[np.ndarray, float]: + """ + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + Average hamming loss + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) + + +class Hinge(SklearnMetric): + """ + Computes the average hinge loss + + Example: + + >>> pred_decision = torch.tensor([-2.17, -0.97, -0.19, -0.43]) + >>> y_true = torch.tensor([1, 1, 0, 0]) + >>> metric = Hinge() + >>> metric(pred_decision, y_true) + tensor([1.6300]) + + """ + + def __init__( + self, + labels: Optional[Sequence] = None, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + labels: Integer array of labels. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('hinge_loss', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels) + + def forward( + self, + pred_decision: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ) -> float: + """ + Args: + pred_decision : Predicted decisions + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. 
+ + Return: + Average hinge loss + + """ + return super().forward(pred_decision=pred_decision, + y_true=y_true, + sample_weight=sample_weight) + + +class Jaccard(SklearnMetric): + """ + Calculates jaccard similarity coefficient score + + Example: + + >>> y_pred = torch.tensor([1, 1, 1]) + >>> y_true = torch.tensor([0, 1, 1]) + >>> metric = Jaccard() + >>> metric(y_pred, y_true) + tensor([0.3333]) + + """ + + def __init__( + self, + labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = 'macro', + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + labels: Integer array of labels. + pos_label: The class to report if ``average='binary'``. + average: This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + * ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + * ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + * ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + * ``'weighted'``: + Calculate metrics for each label, and find their average, weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + * ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + Note that if ``pos_label`` is given in binary classification with + `average != 'binary'`, only that positive class is reported. This + behavior is deprecated and will change in version 0.18. 
+ reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('jaccard_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels, + pos_label=pos_label, + average=average) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ) -> Union[np.ndarray, float]: + """ + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + Jaccard similarity score + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) + + class Precision(SklearnMetric): """ Compute the precision @@ -695,7 +1067,7 @@ class PrecisionRecallCurve(SklearnMetric): self, probas_pred: np.ndarray, y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -729,6 +1101,9 @@ class ROC(SklearnMetric): Note: this implementation is restricted to the binary classification task. + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + Example: >>> y_pred = torch.tensor([0, 1, 2, 3]) @@ -769,7 +1144,7 @@ class ROC(SklearnMetric): self, y_score: np.ndarray, y_true: np.ndarray, - sample_weight: Optional[np.ndarray] = None + sample_weight: Optional[np.ndarray] = None, ) -> Union[np.ndarray, float]: """ Args: @@ -802,6 +1177,9 @@ class AUROC(SklearnMetric): this implementation is restricted to the binary classification task or multilabel classification task in label indicator format. 
+ Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + """ def __init__( @@ -851,3 +1229,515 @@ class AUROC(SklearnMetric): """ return super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight) + + +class ExplainedVariance(SklearnMetric): + """ + Calculates explained variance score + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) + >>> y_true = torch.tensor([3, -0.5, 2, 7]) + >>> metric = ExplainedVariance() + >>> metric(y_pred, y_true) + tensor([0.9572]) + """ + + def __init__( + self, + multioutput: Optional[Union[str, List[float]]] = 'variance_weighted', + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + multioutput: either one of the strings [‘raw_values’, ‘uniform_average’, 'variance_weighted'] + or an array with shape (n_outputs,) that defines how multiple + output values should be aggregated. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('explained_variance_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + multioutput=multioutput) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. 
+ + Return: + Explained variance score + + """ + return super().forward(y_true=y_true, y_pred=y_pred, + sample_weight=sample_weight) + + +class MeanAbsoluteError(SklearnMetric): + """ + Compute absolute error regression loss + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) + >>> y_true = torch.tensor([3, -0.5, 2, 7]) + >>> metric = MeanAbsoluteError() + >>> metric(y_pred, y_true) + tensor([0.5000]) + + """ + + def __init__( + self, + multioutput: Optional[Union[str, List[float]]] = 'uniform_average', + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] + or an array with shape (n_outputs,) that defines how multiple + output values should be aggregated. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('mean_absolute_error', + reduce_group=reduce_group, + reduce_op=reduce_op, + multioutput=multioutput) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. 
+ + Return: + Mean absolute error + + """ + return super().forward(y_true=y_true, + y_pred=y_pred, + sample_weight=sample_weight) + + +class MeanSquaredError(SklearnMetric): + """ + Compute mean squared error loss + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) + >>> y_true = torch.tensor([3, -0.5, 2, 7]) + >>> metric = MeanSquaredError() + >>> metric(y_pred, y_true) + tensor([0.3750]) + >>> metric = MeanSquaredError(squared=True) + >>> metric(y_pred, y_true) + tensor([0.6124]) + + """ + + def __init__( + self, + multioutput: Optional[Union[str, List[float]]] = 'uniform_average', + squared: bool = False, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] + or an array with shape (n_outputs,) that defines how multiple + output values should be aggregated. + squared: if ``True`` returns the rmse value else the mse value + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('mean_squared_error', + reduce_group=reduce_group, + reduce_op=reduce_op, + multioutput=multioutput) + self.squared = squared + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. 
+ + Return: + Mean squared error + + """ + mse = super().forward(y_true=y_true, y_pred=y_pred, + sample_weight=sample_weight) + if self.squared: + mse = np.sqrt(mse) + return mse + + +class MeanSquaredLogError(SklearnMetric): + """ + Calculates the mean squared log error + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2.5, 5, 4, 8]) + >>> y_true = torch.tensor([3, 5, 2.5, 7]) + >>> metric = MeanSquaredLogError() + >>> metric(y_pred, y_true) + tensor([0.0397]) + """ + + def __init__( + self, + multioutput: Optional[Union[str, List[float]]] = 'uniform_average', + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] + or an array with shape (n_outputs,) that defines how multiple + output values should be aggregated. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('mean_squared_log_error', + reduce_group=reduce_group, + reduce_op=reduce_op, + multioutput=multioutput) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. 
+ + Return: + Mean squared log error + + """ + return super().forward(y_true=y_true, y_pred=y_pred, + sample_weight=sample_weight) + + +class MedianAbsoluteError(SklearnMetric): + """ + Calculates the median absolute error + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) + >>> y_true = torch.tensor([3, -0.5, 2, 7]) + >>> metric = MedianAbsoluteError() + >>> metric(y_pred, y_true) + tensor([0.5000]) + """ + + def __init__( + self, + multioutput: Optional[Union[str, List[float]]] = 'uniform_average', + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + multioutput: either one of the strings [‘raw_values’, ‘uniform_average’] + or an array with shape (n_outputs,) that defines how multiple + output values should be aggregated. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('median_absolute_error', + reduce_group=reduce_group, + reduce_op=reduce_op, + multioutput=multioutput) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. 
+ + Return: + Median absolute error + + """ + return super().forward(y_true=y_true, y_pred=y_pred) + + +class R2Score(SklearnMetric): + """ + Calculates the R^2 score also known as coefficient of determination + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2.5, 0.0, 2, 8]) + >>> y_true = torch.tensor([3, -0.5, 2, 7]) + >>> metric = R2Score() + >>> metric(y_pred, y_true) + tensor([0.9486]) + """ + + def __init__( + self, + multioutput: Optional[Union[str, List[float]]] = 'uniform_average', + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + multioutput: either one of the strings [‘raw_values’, ‘uniform_average’, 'variance_weighted'] + or an array with shape (n_outputs,) that defines how multiple + output values should be aggregated. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('r2_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + multioutput=multioutput) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. 
+ + Return: + R^2 score + + """ + return super().forward(y_true=y_true, y_pred=y_pred, + sample_weight=sample_weight) + + +class MeanPoissonDeviance(SklearnMetric): + """ + Calculates the mean poisson deviance regression loss + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2, 0.5, 1, 4]) + >>> y_true = torch.tensor([0.5, 0.5, 2., 2.]) + >>> metric = MeanPoissonDeviance() + >>> metric(y_pred, y_true) + tensor([0.9034]) + """ + + def __init__( + self, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('mean_poisson_deviance', + reduce_group=reduce_group, + reduce_op=reduce_op) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + Mean poisson deviance + + """ + return super().forward(y_true=y_true, y_pred=y_pred, + sample_weight=sample_weight) + + +class MeanGammaDeviance(SklearnMetric): + """ + Calculates the mean gamma deviance regression loss + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([0.5, 0.5, 2., 2.]) + >>> y_true = torch.tensor([2, 0.5, 1, 4]) + >>> metric = MeanGammaDeviance() + >>> metric(y_pred, y_true) + tensor([1.0569]) + """ + + def __init__( + self, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + reduce_group: the process group for DDP reduces (only needed for DDP training). 
+ Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('mean_gamma_deviance', + reduce_group=reduce_group, + reduce_op=reduce_op) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + Mean gamma deviance + + """ + return super().forward(y_true=y_true, y_pred=y_pred, + sample_weight=sample_weight) + + +class MeanTweedieDeviance(SklearnMetric): + """ + Calculates the mean tweedie deviance regression loss + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Example: + + >>> y_pred = torch.tensor([2, 0.5, 1, 4]) + >>> y_true = torch.tensor([0.5, 0.5, 2., 2.]) + >>> metric = MeanTweedieDeviance() + >>> metric(y_pred, y_true) + tensor([1.8125]) + """ + + def __init__( + self, + power: float = 0, + reduce_group: Any = group.WORLD, + reduce_op: Any = ReduceOp.SUM, + ): + """ + Args: + power: tweedie power parameter: + + * power < 0: Extreme stable distribution. Requires: y_pred > 0. + * power = 0 : Normal distribution, output corresponds to mean_squared_error. + y_true and y_pred can be any real numbers. + * power = 1 : Poisson distribution. Requires: y_true >= 0 and y_pred > 0. + * 1 < power < 2 : Compound Poisson distribution. Requires: y_true >= 0 and y_pred > 0. + * power = 2 : Gamma distribution. Requires: y_true > 0 and y_pred > 0. + * power = 3 : Inverse Gaussian distribution. Requires: y_true > 0 and y_pred > 0. + * otherwise : Positive stable distribution. Requires: y_true > 0 and y_pred > 0. + + reduce_group: the process group for DDP reduces (only needed for DDP training). 
+ Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('mean_tweedie_deviance', + reduce_group=reduce_group, + reduce_op=reduce_op, + power=power) + + def forward( + self, + y_pred: np.ndarray, + y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None, + ): + """ + Args: + y_pred: Estimated target values + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + Mean tweedie deviance + + """ + return super().forward(y_true=y_true, y_pred=y_pred, + sample_weight=sample_weight) diff --git a/requirements/extra.txt b/requirements/extra.txt index 31ea41c083..2b8854678d 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -10,7 +10,7 @@ matplotlib>=3.1.1 horovod>=0.19.2 omegaconf>=2.0.0 # scipy>=0.13.3 -scikit-learn>=0.20.0 +scikit-learn>=0.22.2 torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility onnx>=1.7.0 -onnxruntime>=1.3.0 \ No newline at end of file +onnxruntime>=1.3.0 diff --git a/tests/metrics/test_sklearn.py b/tests/metrics/test_sklearn.py index f9f8386656..bef5a4ffe0 100644 --- a/tests/metrics/test_sklearn.py +++ b/tests/metrics/test_sklearn.py @@ -16,21 +16,51 @@ from sklearn.metrics import ( precision_recall_curve as sk_precision_recall_curve, roc_curve as sk_roc_curve, roc_auc_score as sk_roc_auc_score, + balanced_accuracy_score as sk_balanced_accuracy_score, + dcg_score as sk_dcg_score, + mean_absolute_error as sk_mean_absolute_error, + mean_squared_error as sk_mean_squared_error, + mean_squared_log_error as sk_mean_squared_log_error, + median_absolute_error as sk_median_absolute_error, + r2_score as sk_r2_score, + mean_poisson_deviance as sk_mean_poisson_deviance, + mean_gamma_deviance as sk_mean_gamma_deviance, + mean_tweedie_deviance as sk_mean_tweedie_deviance, + explained_variance_score as sk_explained_variance_score, + cohen_kappa_score as 
sk_cohen_kappa_score, + hamming_loss as sk_hamming_loss, + hinge_loss as sk_hinge_loss, + jaccard_score as sk_jaccard_score ) from pytorch_lightning.metrics.converters import _convert_to_numpy from pytorch_lightning.metrics.sklearns import ( Accuracy, - AveragePrecision, AUC, + AveragePrecision, + BalancedAccuracy, ConfusionMatrix, + CohenKappaScore, + DCG, F1, FBeta, + Hamming, + Hinge, + Jaccard, Precision, Recall, PrecisionRecallCurve, ROC, - AUROC + AUROC, + MeanAbsoluteError, + MeanSquaredError, + MeanSquaredLogError, + MedianAbsoluteError, + R2Score, + MeanPoissonDeviance, + MeanGammaDeviance, + MeanTweedieDeviance, + ExplainedVariance, ) from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -86,6 +116,51 @@ def _xy_only(func): {'y_score': torch.rand(size=(128,)), 'y_true': torch.randint(2, size=(128,))}, id='AUROC'), + pytest.param(BalancedAccuracy(), sk_balanced_accuracy_score, + {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))}, + id='BalancedAccuracy'), + pytest.param(DCG(), sk_dcg_score, + {'y_score': torch.rand(size=(128, 3)), 'y_true': torch.randint(3, size=(128, 3))}, + id='DCG'), + pytest.param(ExplainedVariance(), sk_explained_variance_score, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='ExplainedVariance'), + pytest.param(MeanAbsoluteError(), sk_mean_absolute_error, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='MeanAbsolutError'), + pytest.param(MeanSquaredError(), sk_mean_squared_error, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='MeanSquaredError'), + pytest.param(MeanSquaredLogError(), sk_mean_squared_log_error, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='MeanSquaredLogError'), + pytest.param(MedianAbsoluteError(), sk_median_absolute_error, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='MedianAbsoluteError'), + 
pytest.param(R2Score(), sk_r2_score, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='R2Score'), + pytest.param(MeanPoissonDeviance(), sk_mean_poisson_deviance, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='MeanPoissonDeviance'), + pytest.param(MeanGammaDeviance(), sk_mean_gamma_deviance, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='MeanGammaDeviance'), + pytest.param(MeanTweedieDeviance(), sk_mean_tweedie_deviance, + {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))}, + id='MeanTweedieDeviance'), + pytest.param(CohenKappaScore(), sk_cohen_kappa_score, + {'y1': torch.randint(3, size=(128,)), 'y2': torch.randint(3, size=(128,))}, + id='CohenKappaScore'), + pytest.param(Hamming(), sk_hamming_loss, + {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))}, + id='Hamming'), + pytest.param(Hinge(), sk_hinge_loss, + {'pred_decision': torch.randn(size=(128,)), 'y_true': torch.randint(2, size=(128,))}, + id='Hinge'), + pytest.param(Jaccard(average='macro'), partial(sk_jaccard_score, average='macro'), + {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))}, + id='Jaccard') ]) def test_sklearn_metric(metric_class, sklearn_func, inputs): numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)