Sync our torchmetrics wrappers after the 0.4 release (#8205)

Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
Carlos Mocholí 2021-06-30 00:05:48 +02:00 committed by GitHub
parent 9444a08d56
commit 47c76548aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 13 additions and 13 deletions

View File

@ -21,7 +21,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void
class FBeta(_FBeta):
@deprecated_metrics(target=_FBeta)
@deprecated_metrics(target=_FBeta, args_mapping={"multilabel": None})
def __init__(
self,
num_classes: int,
@ -44,7 +44,7 @@ class FBeta(_FBeta):
class F1(_F1):
@deprecated_metrics(target=_F1)
@deprecated_metrics(target=_F1, args_mapping={"multilabel": None})
def __init__(
self,
num_classes: int,

View File

@ -21,7 +21,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void
class Precision(_Precision):
@deprecated_metrics(target=_Precision)
@deprecated_metrics(target=_Precision, args_mapping={"multilabel": None, "is_multiclass": None})
def __init__(
self,
num_classes: Optional[int] = None,
@ -49,7 +49,7 @@ class Precision(_Precision):
class Recall(_Recall):
@deprecated_metrics(target=_Recall)
@deprecated_metrics(target=_Recall, args_mapping={"multilabel": None, "is_multiclass": None})
def __init__(
self,
num_classes: Optional[int] = None,

View File

@ -20,7 +20,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void
class StatScores(_StatScores):
@deprecated_metrics(target=_StatScores)
@deprecated_metrics(target=_StatScores, args_mapping={"is_multiclass": None})
def __init__(
self,
threshold: float = 0.5,

View File

@ -20,7 +20,7 @@ from torchmetrics.functional import fbeta as _fbeta
from pytorch_lightning.metrics.utils import deprecated_metrics, void
@deprecated_metrics(target=_fbeta)
@deprecated_metrics(target=_fbeta, args_mapping={"multilabel": None})
def fbeta(
preds: torch.Tensor,
target: torch.Tensor,
@ -37,7 +37,7 @@ def fbeta(
return void(preds, target, num_classes, beta, threshold, average, multilabel)
@deprecated_metrics(target=_f1)
@deprecated_metrics(target=_f1, args_mapping={"multilabel": None})
def f1(
preds: torch.Tensor,
target: torch.Tensor,

View File

@ -21,7 +21,7 @@ from torchmetrics.functional import recall as _recall
from pytorch_lightning.metrics.utils import deprecated_metrics, void
@deprecated_metrics(target=_precision)
@deprecated_metrics(target=_precision, args_mapping={"is_multiclass": None})
def precision(
preds: torch.Tensor,
target: torch.Tensor,
@ -40,7 +40,7 @@ def precision(
return void(preds, target, average, mdmc_average, ignore_index, num_classes, threshold, top_k, is_multiclass)
@deprecated_metrics(target=_recall)
@deprecated_metrics(target=_recall, args_mapping={"is_multiclass": None})
def recall(
preds: torch.Tensor,
target: torch.Tensor,
@ -59,7 +59,7 @@ def recall(
return void(preds, target, average, mdmc_average, ignore_index, num_classes, threshold, top_k, is_multiclass)
@deprecated_metrics(target=_precision_recall)
@deprecated_metrics(target=_precision_recall, args_mapping={"is_multiclass": None})
def precision_recall(
preds: torch.Tensor,
target: torch.Tensor,

View File

@ -19,7 +19,7 @@ from torchmetrics.functional import stat_scores as _stat_scores
from pytorch_lightning.metrics.utils import deprecated_metrics, void
@deprecated_metrics(target=_stat_scores)
@deprecated_metrics(target=_stat_scores, args_mapping={"is_multiclass": None})
def stat_scores(
preds: torch.Tensor,
target: torch.Tensor,

View File

@ -65,7 +65,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch
return void(prob_tensor, topk, dim)
@deprecated_metrics(target=_to_categorical)
@deprecated_metrics(target=_to_categorical, args_mapping={"tensor": "x"})
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
"""
.. deprecated::

View File

@ -7,7 +7,7 @@ tqdm>=4.41.0
PyYAML>=5.1,<=5.4.1
fsspec[http]>=2021.05.0, !=2021.06.0
tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
torchmetrics>=0.4.0rc1
torchmetrics>=0.4.0
pyDeprecate==0.3.1
packaging>=17.0
typing-extensions # TypedDict support for python<3.8