diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index 1aad67b4e8..58a50f163b 100644 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -21,7 +21,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void class FBeta(_FBeta): - @deprecated_metrics(target=_FBeta) + @deprecated_metrics(target=_FBeta, args_mapping={"multilabel": None}) def __init__( self, num_classes: int, @@ -44,7 +44,7 @@ class FBeta(_FBeta): class F1(_F1): - @deprecated_metrics(target=_F1) + @deprecated_metrics(target=_F1, args_mapping={"multilabel": None}) def __init__( self, num_classes: int, diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py index 112ddaac19..6507f6d071 100644 --- a/pytorch_lightning/metrics/classification/precision_recall.py +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -21,7 +21,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void class Precision(_Precision): - @deprecated_metrics(target=_Precision) + @deprecated_metrics(target=_Precision, args_mapping={"multilabel": None, "is_multiclass": None}) def __init__( self, num_classes: Optional[int] = None, @@ -49,7 +49,7 @@ class Precision(_Precision): class Recall(_Recall): - @deprecated_metrics(target=_Recall) + @deprecated_metrics(target=_Recall, args_mapping={"multilabel": None, "is_multiclass": None}) def __init__( self, num_classes: Optional[int] = None, diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 3b7f1313e7..806ee73e17 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -20,7 +20,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void class StatScores(_StatScores): - @deprecated_metrics(target=_StatScores) + @deprecated_metrics(target=_StatScores, args_mapping={"is_multiclass": None}) def __init__( self, threshold: float = 0.5, diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index 29499a491e..ed3d92e69f 100644 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -20,7 +20,7 @@ from torchmetrics.functional import fbeta as _fbeta from pytorch_lightning.metrics.utils import deprecated_metrics, void -@deprecated_metrics(target=_fbeta) +@deprecated_metrics(target=_fbeta, args_mapping={"multilabel": None}) def fbeta( preds: torch.Tensor, target: torch.Tensor, @@ -37,7 +37,7 @@ def fbeta( return void(preds, target, num_classes, beta, threshold, average, multilabel) -@deprecated_metrics(target=_f1) +@deprecated_metrics(target=_f1, args_mapping={"multilabel": None}) def f1( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index d9752f32bc..367c9c9111 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -21,7 +21,7 @@ from torchmetrics.functional import recall as _recall from pytorch_lightning.metrics.utils import deprecated_metrics, void -@deprecated_metrics(target=_precision) +@deprecated_metrics(target=_precision, args_mapping={"is_multiclass": None}) def precision( preds: torch.Tensor, target: torch.Tensor, @@ -40,7 +40,7 @@ def precision( return void(preds, target, average, mdmc_average, ignore_index, num_classes, threshold, top_k, is_multiclass) -@deprecated_metrics(target=_recall) +@deprecated_metrics(target=_recall, args_mapping={"is_multiclass": None}) def recall( preds: torch.Tensor, target: torch.Tensor, @@ -59,7 +59,7 @@ def recall( return void(preds, target, average, mdmc_average, ignore_index, num_classes, threshold, top_k, is_multiclass) -@deprecated_metrics(target=_precision_recall) +@deprecated_metrics(target=_precision_recall, args_mapping={"is_multiclass": None}) def precision_recall( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 7739223603..da654a54e3 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -19,7 +19,7 @@ from torchmetrics.functional import stat_scores as _stat_scores from pytorch_lightning.metrics.utils import deprecated_metrics, void -@deprecated_metrics(target=_stat_scores) +@deprecated_metrics(target=_stat_scores, args_mapping={"is_multiclass": None}) def stat_scores( preds: torch.Tensor, target: torch.Tensor, diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index b87866a740..dd58e59751 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -65,7 +65,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch return void(prob_tensor, topk, dim) -@deprecated_metrics(target=_to_categorical) +@deprecated_metrics(target=_to_categorical, args_mapping={"tensor": "x"}) def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: """ .. deprecated:: diff --git a/requirements.txt b/requirements.txt index e6b3730366..af311aa4c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,7 @@ tqdm>=4.41.0 PyYAML>=5.1,<=5.4.1 fsspec[http]>=2021.05.0, !=2021.06.0 tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!' -torchmetrics>=0.4.0rc1 +torchmetrics>=0.4.0 pyDeprecate==0.3.1 packaging>=17.0 typing-extensions # TypedDict support for python<3.8