Add remaning sklearn metrics (#2562)

* added balanced accuracy

* added dcg score

* added mean absolute error

* added mean squared error

* fix

* added mean squared log error

* add median absolute error and r2 score

* switch arguments

* added mean poisson deviance

* add mean gamma deviance and mean tweedie deviance

* fix styling

* added explained variance score

* added cohen kappa score

* added hamming, hinge, jaccard

* fix styling

* update sklearn requirement to newer version

* update requirement

* fix doctest

* fix tests

* added balanced accuracy

* added dcg score

* added mean absolute error

* added mean squared error

* fix

* added mean squared log error

* add median absolute error and r2 score

* switch arguments

* added mean poisson deviance

* add mean gamma deviance and mean tweedie deviance

* fix styling

* added explained variance score

* added cohen kappa score

* added hamming, hinge, jaccard

* fix styling

* update sklearn requirement to newer version

* fix doctest

* fix tests

* fix doctest

* fix failing docs

* fix test

* trying to fix errors

* Apply suggestions from code review

* format

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
This commit is contained in:
Nicki Skafte 2020-08-05 11:32:53 +02:00 committed by GitHub
parent ad0f1194aa
commit e3732789d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1070 additions and 13 deletions

View File

@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for Mean in DDP Sync ([#2568](https://github.com/PyTorchLightning/pytorch-lightning/pull/2568))
- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
### Changed
- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))

View File

@ -484,6 +484,17 @@ AveragePrecision (sk)
.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
:noindex:
BalancedAccuracy (sk)
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.BalancedAccuracy
:noindex:
CohenKappaScore (sk)
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.CohenKappaScore
:noindex:
ConfusionMatrix (sk)
^^^^^^^^^^^^^^^^^^^^
@ -491,6 +502,12 @@ ConfusionMatrix (sk)
.. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix
:noindex:
DCG (sk)
^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.DCG
:noindex:
F1 (sk)
^^^^^^^
@ -503,6 +520,24 @@ FBeta (sk)
.. autofunction:: pytorch_lightning.metrics.sklearns.FBeta
:noindex:
Hamming (sk)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Hamming
:noindex:
Hinge (sk)
^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Hinge
:noindex:
Jaccard (sk)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Jaccard
:noindex:
Precision (sk)
^^^^^^^^^^^^^^
@ -532,3 +567,58 @@ AUROC (sk)
.. autofunction:: pytorch_lightning.metrics.sklearns.AUROC
:noindex:
ExplainedVariance (sk)
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.ExplainedVariance
:noindex:
MeanAbsoluteError (sk)
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanAbsoluteError
:noindex:
MeanSquaredError (sk)
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredError
:noindex:
MeanSquaredLogError (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredLogError
:noindex:
MedianAbsoluteError (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MedianAbsoluteError
:noindex:
R2Score (sk)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.R2Score
:noindex:
MeanPoissonDeviance (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanPoissonDeviance
:noindex:
MeanGammaDeviance (sk)
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanGammaDeviance
:noindex:
MeanTweedieDeviance (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance
:noindex:

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ matplotlib>=3.1.1
horovod>=0.19.2
omegaconf>=2.0.0
# scipy>=0.13.3
scikit-learn>=0.20.0
scikit-learn>=0.22.2
torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
onnx>=1.7.0
onnxruntime>=1.3.0
onnxruntime>=1.3.0

View File

@ -16,21 +16,51 @@ from sklearn.metrics import (
precision_recall_curve as sk_precision_recall_curve,
roc_curve as sk_roc_curve,
roc_auc_score as sk_roc_auc_score,
balanced_accuracy_score as sk_balanced_accuracy_score,
dcg_score as sk_dcg_score,
mean_absolute_error as sk_mean_absolute_error,
mean_squared_error as sk_mean_squared_error,
mean_squared_log_error as sk_mean_squared_log_error,
median_absolute_error as sk_median_absolute_error,
r2_score as sk_r2_score,
mean_poisson_deviance as sk_mean_poisson_deviance,
mean_gamma_deviance as sk_mean_gamma_deviance,
mean_tweedie_deviance as sk_mean_tweedie_deviance,
explained_variance_score as sk_explained_variance_score,
cohen_kappa_score as sk_cohen_kappa_score,
hamming_loss as sk_hamming_loss,
hinge_loss as sk_hinge_loss,
jaccard_score as sk_jaccard_score
)
from pytorch_lightning.metrics.converters import _convert_to_numpy
from pytorch_lightning.metrics.sklearns import (
Accuracy,
AveragePrecision,
AUC,
AveragePrecision,
BalancedAccuracy,
ConfusionMatrix,
CohenKappaScore,
DCG,
F1,
FBeta,
Hamming,
Hinge,
Jaccard,
Precision,
Recall,
PrecisionRecallCurve,
ROC,
AUROC
AUROC,
MeanAbsoluteError,
MeanSquaredError,
MeanSquaredLogError,
MedianAbsoluteError,
R2Score,
MeanPoissonDeviance,
MeanGammaDeviance,
MeanTweedieDeviance,
ExplainedVariance,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
@ -86,6 +116,51 @@ def _xy_only(func):
{'y_score': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='AUROC'),
pytest.param(BalancedAccuracy(), sk_balanced_accuracy_score,
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='BalancedAccuracy'),
pytest.param(DCG(), sk_dcg_score,
{'y_score': torch.rand(size=(128, 3)), 'y_true': torch.randint(3, size=(128, 3))},
id='DCG'),
pytest.param(ExplainedVariance(), sk_explained_variance_score,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='ExplainedVariance'),
pytest.param(MeanAbsoluteError(), sk_mean_absolute_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanAbsolutError'),
pytest.param(MeanSquaredError(), sk_mean_squared_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanSquaredError'),
pytest.param(MeanSquaredLogError(), sk_mean_squared_log_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanSquaredLogError'),
pytest.param(MedianAbsoluteError(), sk_median_absolute_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MedianAbsoluteError'),
pytest.param(R2Score(), sk_r2_score,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='R2Score'),
pytest.param(MeanPoissonDeviance(), sk_mean_poisson_deviance,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanPoissonDeviance'),
pytest.param(MeanGammaDeviance(), sk_mean_gamma_deviance,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanGammaDeviance'),
pytest.param(MeanTweedieDeviance(), sk_mean_tweedie_deviance,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanTweedieDeviance'),
pytest.param(CohenKappaScore(), sk_cohen_kappa_score,
{'y1': torch.randint(3, size=(128,)), 'y2': torch.randint(3, size=(128,))},
id='CohenKappaScore'),
pytest.param(Hamming(), sk_hamming_loss,
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='Hamming'),
pytest.param(Hinge(), sk_hinge_loss,
{'pred_decision': torch.randn(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
id='Hinge'),
pytest.param(Jaccard(average='macro'), partial(sk_jaccard_score, average='macro'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='Jaccard')
])
def test_sklearn_metric(metric_class, sklearn_func, inputs):
numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)