Add remaning sklearn metrics (#2562)
* added balanced accuracy * added dcg score * added mean absolute error * added mean squared error * fix * added mean squared log error * add median absolute error and r2 score * switch arguments * added mean poisson deviance * add mean gamma deviance and mean tweedie deviance * fix styling * added explained variance score * added cohen kappa score * added hamming, hinge, jaccard * fix styling * update sklearn requirement to newer version * update requirement * fix doctest * fix tests * added balanced accuracy * added dcg score * added mean absolute error * added mean squared error * fix * added mean squared log error * add median absolute error and r2 score * switch arguments * added mean poisson deviance * add mean gamma deviance and mean tweedie deviance * fix styling * added explained variance score * added cohen kappa score * added hamming, hinge, jaccard * fix styling * update sklearn requirement to newer version * fix doctest * fix tests * fix doctest * fix failing docs * fix test * trying to fix errors * Apply suggestions from code review * format Co-authored-by: Nicki Skafte <nugginea@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
This commit is contained in:
parent
ad0f1194aa
commit
e3732789d7
|
@ -29,6 +29,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Added support for Mean in DDP Sync ([#2568](https://github.com/PyTorchLightning/pytorch-lightning/pull/2568))
|
||||
|
||||
- Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
|
||||
|
||||
### Changed
|
||||
|
||||
- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
|
||||
|
|
|
@ -484,6 +484,17 @@ AveragePrecision (sk)
|
|||
.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
|
||||
:noindex:
|
||||
|
||||
BalancedAccuracy (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.BalancedAccuracy
|
||||
:noindex:
|
||||
|
||||
CohenKappaScore (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.CohenKappaScore
|
||||
:noindex:
|
||||
|
||||
ConfusionMatrix (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
@ -491,6 +502,12 @@ ConfusionMatrix (sk)
|
|||
.. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix
|
||||
:noindex:
|
||||
|
||||
DCG (sk)
|
||||
^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.DCG
|
||||
:noindex:
|
||||
|
||||
F1 (sk)
|
||||
^^^^^^^
|
||||
|
||||
|
@ -503,6 +520,24 @@ FBeta (sk)
|
|||
.. autofunction:: pytorch_lightning.metrics.sklearns.FBeta
|
||||
:noindex:
|
||||
|
||||
Hamming (sk)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Hamming
|
||||
:noindex:
|
||||
|
||||
Hinge (sk)
|
||||
^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Hinge
|
||||
:noindex:
|
||||
|
||||
Jaccard (sk)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Jaccard
|
||||
:noindex:
|
||||
|
||||
Precision (sk)
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
@ -532,3 +567,58 @@ AUROC (sk)
|
|||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.AUROC
|
||||
:noindex:
|
||||
|
||||
ExplainedVariance (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.ExplainedVariance
|
||||
:noindex:
|
||||
|
||||
MeanAbsoluteError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanAbsoluteError
|
||||
:noindex:
|
||||
|
||||
MeanSquaredError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredError
|
||||
:noindex:
|
||||
|
||||
MeanSquaredLogError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredLogError
|
||||
:noindex:
|
||||
|
||||
MedianAbsoluteError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MedianAbsoluteError
|
||||
:noindex:
|
||||
|
||||
R2Score (sk)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.R2Score
|
||||
:noindex:
|
||||
|
||||
MeanPoissonDeviance (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanPoissonDeviance
|
||||
:noindex:
|
||||
|
||||
MeanGammaDeviance (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanGammaDeviance
|
||||
:noindex:
|
||||
|
||||
MeanTweedieDeviance (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance
|
||||
:noindex:
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -10,7 +10,7 @@ matplotlib>=3.1.1
|
|||
horovod>=0.19.2
|
||||
omegaconf>=2.0.0
|
||||
# scipy>=0.13.3
|
||||
scikit-learn>=0.20.0
|
||||
scikit-learn>=0.22.2
|
||||
torchtext>=0.3.1, <0.7 # TODO: temporary fix fix for compatibility
|
||||
onnx>=1.7.0
|
||||
onnxruntime>=1.3.0
|
||||
onnxruntime>=1.3.0
|
||||
|
|
|
@ -16,21 +16,51 @@ from sklearn.metrics import (
|
|||
precision_recall_curve as sk_precision_recall_curve,
|
||||
roc_curve as sk_roc_curve,
|
||||
roc_auc_score as sk_roc_auc_score,
|
||||
balanced_accuracy_score as sk_balanced_accuracy_score,
|
||||
dcg_score as sk_dcg_score,
|
||||
mean_absolute_error as sk_mean_absolute_error,
|
||||
mean_squared_error as sk_mean_squared_error,
|
||||
mean_squared_log_error as sk_mean_squared_log_error,
|
||||
median_absolute_error as sk_median_absolute_error,
|
||||
r2_score as sk_r2_score,
|
||||
mean_poisson_deviance as sk_mean_poisson_deviance,
|
||||
mean_gamma_deviance as sk_mean_gamma_deviance,
|
||||
mean_tweedie_deviance as sk_mean_tweedie_deviance,
|
||||
explained_variance_score as sk_explained_variance_score,
|
||||
cohen_kappa_score as sk_cohen_kappa_score,
|
||||
hamming_loss as sk_hamming_loss,
|
||||
hinge_loss as sk_hinge_loss,
|
||||
jaccard_score as sk_jaccard_score
|
||||
)
|
||||
|
||||
from pytorch_lightning.metrics.converters import _convert_to_numpy
|
||||
from pytorch_lightning.metrics.sklearns import (
|
||||
Accuracy,
|
||||
AveragePrecision,
|
||||
AUC,
|
||||
AveragePrecision,
|
||||
BalancedAccuracy,
|
||||
ConfusionMatrix,
|
||||
CohenKappaScore,
|
||||
DCG,
|
||||
F1,
|
||||
FBeta,
|
||||
Hamming,
|
||||
Hinge,
|
||||
Jaccard,
|
||||
Precision,
|
||||
Recall,
|
||||
PrecisionRecallCurve,
|
||||
ROC,
|
||||
AUROC
|
||||
AUROC,
|
||||
MeanAbsoluteError,
|
||||
MeanSquaredError,
|
||||
MeanSquaredLogError,
|
||||
MedianAbsoluteError,
|
||||
R2Score,
|
||||
MeanPoissonDeviance,
|
||||
MeanGammaDeviance,
|
||||
MeanTweedieDeviance,
|
||||
ExplainedVariance,
|
||||
)
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
|
@ -86,6 +116,51 @@ def _xy_only(func):
|
|||
{'y_score': torch.rand(size=(128,)),
|
||||
'y_true': torch.randint(2, size=(128,))},
|
||||
id='AUROC'),
|
||||
pytest.param(BalancedAccuracy(), sk_balanced_accuracy_score,
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='BalancedAccuracy'),
|
||||
pytest.param(DCG(), sk_dcg_score,
|
||||
{'y_score': torch.rand(size=(128, 3)), 'y_true': torch.randint(3, size=(128, 3))},
|
||||
id='DCG'),
|
||||
pytest.param(ExplainedVariance(), sk_explained_variance_score,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='ExplainedVariance'),
|
||||
pytest.param(MeanAbsoluteError(), sk_mean_absolute_error,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='MeanAbsolutError'),
|
||||
pytest.param(MeanSquaredError(), sk_mean_squared_error,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='MeanSquaredError'),
|
||||
pytest.param(MeanSquaredLogError(), sk_mean_squared_log_error,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='MeanSquaredLogError'),
|
||||
pytest.param(MedianAbsoluteError(), sk_median_absolute_error,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='MedianAbsoluteError'),
|
||||
pytest.param(R2Score(), sk_r2_score,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='R2Score'),
|
||||
pytest.param(MeanPoissonDeviance(), sk_mean_poisson_deviance,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='MeanPoissonDeviance'),
|
||||
pytest.param(MeanGammaDeviance(), sk_mean_gamma_deviance,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='MeanGammaDeviance'),
|
||||
pytest.param(MeanTweedieDeviance(), sk_mean_tweedie_deviance,
|
||||
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
|
||||
id='MeanTweedieDeviance'),
|
||||
pytest.param(CohenKappaScore(), sk_cohen_kappa_score,
|
||||
{'y1': torch.randint(3, size=(128,)), 'y2': torch.randint(3, size=(128,))},
|
||||
id='CohenKappaScore'),
|
||||
pytest.param(Hamming(), sk_hamming_loss,
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='Hamming'),
|
||||
pytest.param(Hinge(), sk_hinge_loss,
|
||||
{'pred_decision': torch.randn(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
|
||||
id='Hinge'),
|
||||
pytest.param(Jaccard(average='macro'), partial(sk_jaccard_score, average='macro'),
|
||||
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
|
||||
id='Jaccard')
|
||||
])
|
||||
def test_sklearn_metric(metric_class, sklearn_func, inputs):
|
||||
numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)
|
||||
|
|
Loading…
Reference in New Issue