Sync our torchmetrics wrappers after the 0.4 release (#8205)
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
This commit is contained in:
parent
9444a08d56
commit
47c76548aa
|
@ -21,7 +21,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void
|
|||
|
||||
class FBeta(_FBeta):
|
||||
|
||||
@deprecated_metrics(target=_FBeta)
|
||||
@deprecated_metrics(target=_FBeta, args_mapping={"multilabel": None})
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int,
|
||||
|
@ -44,7 +44,7 @@ class FBeta(_FBeta):
|
|||
|
||||
class F1(_F1):
|
||||
|
||||
@deprecated_metrics(target=_F1)
|
||||
@deprecated_metrics(target=_F1, args_mapping={"multilabel": None})
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: int,
|
||||
|
|
|
@ -21,7 +21,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void
|
|||
|
||||
class Precision(_Precision):
|
||||
|
||||
@deprecated_metrics(target=_Precision)
|
||||
@deprecated_metrics(target=_Precision, args_mapping={"multilabel": None, "is_multiclass": None})
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
|
@ -49,7 +49,7 @@ class Precision(_Precision):
|
|||
|
||||
class Recall(_Recall):
|
||||
|
||||
@deprecated_metrics(target=_Recall)
|
||||
@deprecated_metrics(target=_Recall, args_mapping={"multilabel": None, "is_multiclass": None})
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
|
|
|
@ -20,7 +20,7 @@ from pytorch_lightning.metrics.utils import deprecated_metrics, void
|
|||
|
||||
class StatScores(_StatScores):
|
||||
|
||||
@deprecated_metrics(target=_StatScores)
|
||||
@deprecated_metrics(target=_StatScores, args_mapping={"is_multiclass": None})
|
||||
def __init__(
|
||||
self,
|
||||
threshold: float = 0.5,
|
||||
|
|
|
@ -20,7 +20,7 @@ from torchmetrics.functional import fbeta as _fbeta
|
|||
from pytorch_lightning.metrics.utils import deprecated_metrics, void
|
||||
|
||||
|
||||
@deprecated_metrics(target=_fbeta)
|
||||
@deprecated_metrics(target=_fbeta, args_mapping={"multilabel": None})
|
||||
def fbeta(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
|
@ -37,7 +37,7 @@ def fbeta(
|
|||
return void(preds, target, num_classes, beta, threshold, average, multilabel)
|
||||
|
||||
|
||||
@deprecated_metrics(target=_f1)
|
||||
@deprecated_metrics(target=_f1, args_mapping={"multilabel": None})
|
||||
def f1(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
|
|
|
@ -21,7 +21,7 @@ from torchmetrics.functional import recall as _recall
|
|||
from pytorch_lightning.metrics.utils import deprecated_metrics, void
|
||||
|
||||
|
||||
@deprecated_metrics(target=_precision)
|
||||
@deprecated_metrics(target=_precision, args_mapping={"is_multiclass": None})
|
||||
def precision(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
|
@ -40,7 +40,7 @@ def precision(
|
|||
return void(preds, target, average, mdmc_average, ignore_index, num_classes, threshold, top_k, is_multiclass)
|
||||
|
||||
|
||||
@deprecated_metrics(target=_recall)
|
||||
@deprecated_metrics(target=_recall, args_mapping={"is_multiclass": None})
|
||||
def recall(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
|
@ -59,7 +59,7 @@ def recall(
|
|||
return void(preds, target, average, mdmc_average, ignore_index, num_classes, threshold, top_k, is_multiclass)
|
||||
|
||||
|
||||
@deprecated_metrics(target=_precision_recall)
|
||||
@deprecated_metrics(target=_precision_recall, args_mapping={"is_multiclass": None})
|
||||
def precision_recall(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
|
|
|
@ -19,7 +19,7 @@ from torchmetrics.functional import stat_scores as _stat_scores
|
|||
from pytorch_lightning.metrics.utils import deprecated_metrics, void
|
||||
|
||||
|
||||
@deprecated_metrics(target=_stat_scores)
|
||||
@deprecated_metrics(target=_stat_scores, args_mapping={"is_multiclass": None})
|
||||
def stat_scores(
|
||||
preds: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
|
|
|
@ -65,7 +65,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch
|
|||
return void(prob_tensor, topk, dim)
|
||||
|
||||
|
||||
@deprecated_metrics(target=_to_categorical)
|
||||
@deprecated_metrics(target=_to_categorical, args_mapping={"tensor": "x"})
|
||||
def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
.. deprecated::
|
||||
|
|
|
@ -7,7 +7,7 @@ tqdm>=4.41.0
|
|||
PyYAML>=5.1,<=5.4.1
|
||||
fsspec[http]>=2021.05.0, !=2021.06.0
|
||||
tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
|
||||
torchmetrics>=0.4.0rc1
|
||||
torchmetrics>=0.4.0
|
||||
pyDeprecate==0.3.1
|
||||
packaging>=17.0
|
||||
typing-extensions # TypedDict support for python<3.8
|
||||
|
|
Loading…
Reference in New Issue