drop deprecated fbeta metrics (#5322)

* drop deprecated fbeta metrics

* flake8

* imports

* chlog
This commit is contained in:
Jirka Borovec 2021-01-02 01:49:23 +01:00 committed by GitHub
parent fb90eec515
commit f2c2a692e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 6 additions and 94 deletions

View File

@ -32,6 +32,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed ### Removed
- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
- Removed deprecated `Fbeta`, `f1_score` and `fbeta_score` metrics ([#5322](https://github.com/PyTorchLightning/pytorch-lightning/pull/5322))
### Fixed ### Fixed

View File

@ -14,7 +14,7 @@
from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401 from pytorch_lightning.metrics.classification.accuracy import Accuracy # noqa: F401
from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401 from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from pytorch_lightning.metrics.classification.f_beta import FBeta, Fbeta, F1 # noqa: F401 from pytorch_lightning.metrics.classification.f_beta import FBeta, F1 # noqa: F401
from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance # noqa: F401 from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance # noqa: F401
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401 from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401

View File

@ -132,34 +132,6 @@ class FBeta(Metric):
self.actual_positives, self.beta, self.average) self.actual_positives, self.beta, self.average)
# todo: remove in v1.2
class Fbeta(FBeta):
r"""
Computes `F-score <https://en.wikipedia.org/wiki/F-score>`_
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.classification.f_beta.FBeta`
"""
def __init__(
self,
num_classes: int,
beta: float = 1.0,
threshold: float = 0.5,
average: str = "micro",
multilabel: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
rank_zero_warn(
"This `Fbeta` was deprecated in v1.0.x in favor of"
" `from pytorch_lightning.metrics.classification.f_beta import FBeta`."
" It will be removed in v1.2.0", DeprecationWarning
)
super().__init__(
num_classes, beta, threshold, average, multilabel, compute_on_step, dist_sync_on_step, process_group
)
class F1(FBeta): class F1(FBeta):
""" """
Computes F1 metric. F1 metrics correspond to a harmonic mean of the Computes F1 metric. F1 metrics correspond to a harmonic mean of the

View File

@ -16,8 +16,6 @@ from pytorch_lightning.metrics.functional.classification import ( # noqa: F401
auc, auc,
auroc, auroc,
dice_score, dice_score,
f1_score,
fbeta_score,
get_num_classes, get_num_classes,
iou, iou,
multiclass_auroc, multiclass_auroc,

View File

@ -18,7 +18,6 @@ import torch
from distutils.version import LooseVersion from distutils.version import LooseVersion
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1
from pytorch_lightning.metrics.functional.precision_recall_curve import ( from pytorch_lightning.metrics.functional.precision_recall_curve import (
_binary_clf_curve, _binary_clf_curve,
precision_recall_curve as __prc precision_recall_curve as __prc
@ -837,48 +836,3 @@ def average_precision(
" It will be removed in v1.3.0", DeprecationWarning " It will be removed in v1.3.0", DeprecationWarning
) )
return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
# todo: remove in 1.2
def fbeta_score(
pred: torch.Tensor,
target: torch.Tensor,
beta: float,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
) -> torch.Tensor:
"""
Computes the F-beta score which is a weighted harmonic mean of precision and recall.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.fbeta`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.0.x in favor of"
" `from pytorch_lightning.metrics.functional.f_beta import fbeta`."
" It will be removed in v1.2.0", DeprecationWarning
)
if num_classes is None:
num_classes = get_num_classes(pred, target)
return __fb(preds=pred, target=target, beta=beta, num_classes=num_classes, average=class_reduction)
# todo: remove in 1.2
def f1_score(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
) -> torch.Tensor:
"""
Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.f_beta.f1`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.0.x in favor of"
" `from pytorch_lightning.metrics.functional.f_beta import f1`."
" It will be removed in v1.2.0", DeprecationWarning
)
if num_classes is None:
num_classes = get_num_classes(pred, target)
return __f1(preds=pred, target=target, num_classes=num_classes, average=class_reduction)

View File

@ -12,20 +12,3 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Test deprecated functionality which will be removed in vX.Y.Z""" """Test deprecated functionality which will be removed in vX.Y.Z"""
import pytest
import torch
def test_v1_2_0_deprecated_metrics():
from pytorch_lightning.metrics.classification import Fbeta
from pytorch_lightning.metrics.functional.classification import f1_score, fbeta_score
with pytest.deprecated_call(match='will be removed in v1.2'):
Fbeta(2)
with pytest.deprecated_call(match='will be removed in v1.2'):
fbeta_score(torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 1]), 0.2)
with pytest.deprecated_call(match='will be removed in v1.2'):
f1_score(torch.tensor([0, 1, 0, 1]), torch.tensor([0, 1, 0, 0]))