Metric aggregation (#3321)
* metric aggregation
* metric aggregation
* add at_least_1d
* fix output formatting
* add metric tests
* add missing test case
* remove reduce_op from metric classes
* fix reduce_op stuff
* start test fixing
* fix tests due to aggregation
* fix faulty import
* Apply suggestions from code review

  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* remove reduce_op docstrings
* add compute
* remove import
* remove collection metric
* update base class
* update tests
* Update metric.py
* Update metric.py
* Apply suggestions from code review
* change default aggregate

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
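A minimal usage sketch of the aggregation flow this commit introduces, inferred from the hooks and the `aggregated` property visible in the diff below. The call pattern shown here is an assumption for illustration, not part of the commit itself:

    import torch
    from pytorch_lightning.metrics.classification import Accuracy

    # Sketch: each forward call stores its (DDP-synced) step value internally;
    # `aggregated` reduces the stored step values (mean by default), resets the
    # metric, and runs `compute` on the result.
    metric = Accuracy()

    batches = [
        (torch.tensor([0, 1, 2, 3]), torch.tensor([0, 1, 2, 2])),
        (torch.tensor([1, 1, 2, 3]), torch.tensor([1, 1, 2, 3])),
    ]
    for pred, target in batches:
        step_acc = metric(pred, target)  # per-batch accuracy

    epoch_acc = metric.aggregated  # mean over all stored step values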
parent 50b8388f03
commit 4dc4c8cfa5
pytorch_lightning/metrics/classification.py

@@ -30,9 +30,9 @@ from pytorch_lightning.metrics.functional.classification import (
     precision,
     precision_recall_curve,
     recall,
-    roc
+    roc,
 )
-from pytorch_lightning.metrics.metric import TensorCollectionMetric, TensorMetric
+from pytorch_lightning.metrics.metric import TensorMetric


 class Accuracy(TensorMetric):
@@ -44,17 +44,16 @@ class Accuracy(TensorMetric):
         >>> pred = torch.tensor([0, 1, 2, 3])
         >>> target = torch.tensor([0, 1, 2, 2])
         >>> metric = Accuracy()
-        >>> metric(pred, target)
-        tensor(0.7500)
+        >>> metric(pred, target).item()
+        0.75

     """

     def __init__(
-            self,
-            num_classes: Optional[int] = None,
-            reduction: str = 'elementwise_mean',
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        num_classes: Optional[int] = None,
+        reduction: str = "elementwise_mean",
+        reduce_group: Any = None,
     ):
         """
         Args:
@@ -65,11 +64,8 @@ class Accuracy(TensorMetric):
                 - none: pass array
                 - sum: add elements
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='accuracy',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(name="accuracy", reduce_group=reduce_group)
         self.num_classes = num_classes
         self.reduction = reduction
@@ -84,8 +80,7 @@ class Accuracy(TensorMetric):
         Return:
             A Tensor with the classification score.
         """
-        return accuracy(pred=pred, target=target,
-                        num_classes=self.num_classes, reduction=self.reduction)
+        return accuracy(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)


 class ConfusionMatrix(TensorMetric):
@@ -106,22 +101,21 @@ class ConfusionMatrix(TensorMetric):
     """

     def __init__(
-            self,
-            num_classes: Optional[int] = None,
-            normalize: bool = False,
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        num_classes: Optional[int] = None,
+        normalize: bool = False,
+        reduce_group: Any = None,
     ):
         """
         Args:
             num_classes: number of classes
             normalize: whether to compute a normalized confusion matrix
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='confusion_matrix',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="confusion_matrix",
+            reduce_group=reduce_group,
+        )
         self.normalize = normalize
         self.num_classes = num_classes
@@ -140,8 +134,16 @@ class ConfusionMatrix(TensorMetric):
                                 normalize=self.normalize,
                                 num_classes=self.num_classes)

+    def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
+        """Aggregates results by stacking them instead of concatenating before averaging.
+
+        Returns:
+            the aggregated results
+        """
+        return torch.stack(tensors).mean(0)
+

-class PrecisionRecallCurve(TensorCollectionMetric):
+class PrecisionRecallCurve(TensorMetric):
     """
     Computes the precision recall curve

@@ -161,28 +163,27 @@ class PrecisionRecallCurve(TensorCollectionMetric):
     """

     def __init__(
-            self,
-            pos_label: int = 1,
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        pos_label: int = 1,
+        reduce_group: Any = None,
     ):
         """
         Args:
             pos_label: positive label indicator
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='precision_recall_curve',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="precision_recall_curve",
+            reduce_group=reduce_group,
+        )

         self.pos_label = pos_label

     def forward(
-            self,
-            pred: torch.Tensor,
-            target: torch.Tensor,
-            sample_weight: Optional[Sequence] = None,
+        self,
+        pred: torch.Tensor,
+        target: torch.Tensor,
+        sample_weight: Optional[Sequence] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Actual metric computation
@@ -197,9 +198,7 @@ class PrecisionRecallCurve(TensorCollectionMetric):
             - recall values
            - threshold values
         """
-        return precision_recall_curve(pred=pred, target=target,
-                                      sample_weight=sample_weight,
-                                      pos_label=self.pos_label)
+        return precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)


 class Precision(TensorMetric):
@@ -217,11 +216,10 @@ class Precision(TensorMetric):
     """

     def __init__(
-            self,
-            num_classes: Optional[int] = None,
-            reduction: str = 'elementwise_mean',
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        num_classes: Optional[int] = None,
+        reduction: str = "elementwise_mean",
+        reduce_group: Any = None,
     ):
         """
         Args:
@@ -232,11 +230,11 @@ class Precision(TensorMetric):
                 - none: pass array
                 - sum: add elements
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='precision',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="precision",
+            reduce_group=reduce_group,
+        )
         self.num_classes = num_classes
         self.reduction = reduction
@@ -251,9 +249,7 @@ class Precision(TensorMetric):
         Return:
             A Tensor with the classification score.
         """
-        return precision(pred=pred, target=target,
-                         num_classes=self.num_classes,
-                         reduction=self.reduction)
+        return precision(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)


 class Recall(TensorMetric):
@@ -271,11 +267,10 @@ class Recall(TensorMetric):
     """

     def __init__(
-            self,
-            num_classes: Optional[int] = None,
-            reduction: str = 'elementwise_mean',
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        num_classes: Optional[int] = None,
+        reduction: str = "elementwise_mean",
+        reduce_group: Any = None,
     ):
         """
         Args:
@@ -286,11 +281,11 @@ class Recall(TensorMetric):
                 - none: pass array
                 - sum: add elements
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='recall',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="recall",
+            reduce_group=reduce_group,
+        )

         self.num_classes = num_classes
         self.reduction = reduction
@@ -306,10 +301,7 @@ class Recall(TensorMetric):
         Return:
             A Tensor with the classification score.
         """
-        return recall(pred=pred,
-                      target=target,
-                      num_classes=self.num_classes,
-                      reduction=self.reduction)
+        return recall(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)


 class AveragePrecision(TensorMetric):
@@ -327,28 +319,24 @@ class AveragePrecision(TensorMetric):
     """

     def __init__(
-            self,
-            pos_label: int = 1,
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        pos_label: int = 1,
+        reduce_group: Any = None,
     ):
         """
         Args:
             pos_label: positive label indicator
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='AP',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="AP",
+            reduce_group=reduce_group,
+        )

         self.pos_label = pos_label

     def forward(
-            self,
-            pred: torch.Tensor,
-            target: torch.Tensor,
-            sample_weight: Optional[Sequence] = None
+        self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
     ) -> torch.Tensor:
         """
         Actual metric computation
@@ -361,9 +349,7 @@ class AveragePrecision(TensorMetric):
         Return:
             torch.Tensor: classification score
         """
-        return average_precision(pred=pred, target=target,
-                                 sample_weight=sample_weight,
-                                 pos_label=self.pos_label)
+        return average_precision(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)


 class AUROC(TensorMetric):
@@ -381,28 +367,24 @@ class AUROC(TensorMetric):
     """

     def __init__(
-            self,
-            pos_label: int = 1,
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        pos_label: int = 1,
+        reduce_group: Any = None,
     ):
         """
         Args:
             pos_label: positive label indicator
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='auroc',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="auroc",
+            reduce_group=reduce_group,
+        )

         self.pos_label = pos_label

     def forward(
-            self,
-            pred: torch.Tensor,
-            target: torch.Tensor,
-            sample_weight: Optional[Sequence] = None
+        self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
     ) -> torch.Tensor:
         """
         Actual metric computation
@@ -415,9 +397,7 @@ class AUROC(TensorMetric):
         Return:
             torch.Tensor: classification score
         """
-        return auroc(pred=pred, target=target,
-                     sample_weight=sample_weight,
-                     pos_label=self.pos_label)
+        return auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)


 class FBeta(TensorMetric):
@@ -435,12 +415,11 @@ class FBeta(TensorMetric):
     """

     def __init__(
-            self,
-            beta: float,
-            num_classes: Optional[int] = None,
-            reduction: str = 'elementwise_mean',
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        beta: float,
+        num_classes: Optional[int] = None,
+        reduction: str = "elementwise_mean",
+        reduce_group: Any = None,
     ):
         """
         Args:
@@ -452,11 +431,11 @@ class FBeta(TensorMetric):
                 - none: pass array
                 - sum: add elements
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for DDP reduction
         """
-        super().__init__(name='fbeta',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="fbeta",
+            reduce_group=reduce_group,
+        )

         self.beta = beta
         self.num_classes = num_classes
@@ -473,9 +452,9 @@ class FBeta(TensorMetric):
         Return:
             torch.Tensor: classification score
         """
-        return fbeta_score(pred=pred, target=target,
-                           beta=self.beta, num_classes=self.num_classes,
-                           reduction=self.reduction)
+        return fbeta_score(
+            pred=pred, target=target, beta=self.beta, num_classes=self.num_classes, reduction=self.reduction
+        )


 class F1(TensorMetric):
@@ -493,11 +472,10 @@ class F1(TensorMetric):
     """

     def __init__(
-            self,
-            num_classes: Optional[int] = None,
-            reduction: str = 'elementwise_mean',
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        num_classes: Optional[int] = None,
+        reduction: str = "elementwise_mean",
+        reduce_group: Any = None,
     ):
         """
         Args:
@@ -508,11 +486,11 @@ class F1(TensorMetric):
                 - none: pass array
                 - sum: add elements
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='f1',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="f1",
+            reduce_group=reduce_group,
+        )

         self.num_classes = num_classes
         self.reduction = reduction
@@ -528,12 +506,10 @@ class F1(TensorMetric):
         Return:
             torch.Tensor: classification score
         """
-        return f1_score(pred=pred, target=target,
-                        num_classes=self.num_classes,
-                        reduction=self.reduction)
+        return f1_score(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)


-class ROC(TensorCollectionMetric):
+class ROC(TensorMetric):
     """
     Computes the Receiver Operator Characteristic (ROC)

@@ -553,28 +529,24 @@ class ROC(TensorCollectionMetric):
     """

     def __init__(
-            self,
-            pos_label: int = 1,
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        pos_label: int = 1,
+        reduce_group: Any = None,
     ):
         """
         Args:
             pos_label: positive label indicator
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='roc',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="roc",
+            reduce_group=reduce_group,
+        )

         self.pos_label = pos_label

     def forward(
-            self,
-            pred: torch.Tensor,
-            target: torch.Tensor,
-            sample_weight: Optional[Sequence] = None
+        self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Actual metric computation
@@ -589,12 +561,10 @@ class ROC(TensorCollectionMetric):
             - true positive rate
            - thresholds
         """
-        return roc(pred=pred, target=target,
-                   sample_weight=sample_weight,
-                   pos_label=self.pos_label)
+        return roc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)


-class MulticlassROC(TensorCollectionMetric):
+class MulticlassROC(TensorMetric):
     """
     Computes the multiclass ROC

@@ -615,27 +585,27 @@ class MulticlassROC(TensorCollectionMetric):
     """

     def __init__(
-            self,
-            num_classes: Optional[int] = None,
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        num_classes: Optional[int] = None,
+        reduce_group: Any = None,
     ):
         """
         Args:
             num_classes: number of classes
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='multiclass_roc',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="multiclass_roc",
+            reduce_group=reduce_group,
+        )

         self.num_classes = num_classes

     def forward(
-            self, pred: torch.Tensor,
-            target: torch.Tensor,
-            sample_weight: Optional[Sequence] = None,
+        self,
+        pred: torch.Tensor,
+        target: torch.Tensor,
+        sample_weight: Optional[Sequence] = None,
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """
         Actual metric computation
@@ -649,13 +619,19 @@ class MulticlassROC(TensorCollectionMetric):
             tuple: A tuple consisting of one tuple per class, holding false positive rate, true positive rate and thresholds

         """
-        return multiclass_roc(pred=pred,
-                              target=target,
-                              sample_weight=sample_weight,
-                              num_classes=self.num_classes)
+        return multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes)
+
+    def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """Aggregates results by stacking them instead of concatenating before averaging.
+
+        Returns:
+            the aggregated results
+        """
+
+        return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])


-class MulticlassPrecisionRecallCurve(TensorCollectionMetric):
+class MulticlassPrecisionRecallCurve(TensorMetric):
     """Computes the multiclass PR Curve

     Example:
@@ -674,29 +650,28 @@ class MulticlassPrecisionRecallCurve(TensorCollectionMetric):
     """

     def __init__(
-            self,
-            num_classes: Optional[int] = None,
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        num_classes: Optional[int] = None,
+        reduce_group: Any = None,
     ):
         """
         Args:
             num_classes: number of classes
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction

         """
-        super().__init__(name='multiclass_precision_recall_curve',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="multiclass_precision_recall_curve",
+            reduce_group=reduce_group,
+        )

         self.num_classes = num_classes

     def forward(
-            self,
-            pred: torch.Tensor,
-            target: torch.Tensor,
-            sample_weight: Optional[Sequence] = None,
+        self,
+        pred: torch.Tensor,
+        target: torch.Tensor,
+        sample_weight: Optional[Sequence] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Actual metric computation
@@ -710,10 +685,18 @@ class MulticlassPrecisionRecallCurve(TensorCollectionMetric):
             tuple: A tuple consisting of one tuple per class, holding precision, recall and thresholds

         """
-        return multiclass_precision_recall_curve(pred=pred,
-                                                 target=target,
-                                                 sample_weight=sample_weight,
-                                                 num_classes=self.num_classes)
+        return multiclass_precision_recall_curve(
+            pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes
+        )
+
+    def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """Aggregates results by stacking them instead of concatenating before averaging.
+
+        Returns:
+            the aggregated results
+        """
+
+        return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])


 class DiceCoefficient(TensorMetric):
@@ -733,12 +716,12 @@ class DiceCoefficient(TensorMetric):
     """

     def __init__(
-            self,
-            include_background: bool = False,
-            nan_score: float = 0.0, no_fg_score: float = 0.0,
-            reduction: str = 'elementwise_mean',
-            reduce_group: Any = None,
-            reduce_op: Any = None,
+        self,
+        include_background: bool = False,
+        nan_score: float = 0.0,
+        no_fg_score: float = 0.0,
+        reduction: str = "elementwise_mean",
+        reduce_group: Any = None,
     ):
         """
         Args:
@@ -751,11 +734,11 @@ class DiceCoefficient(TensorMetric):
                 - none: pass array
                 - sum: add elements
             reduce_group: the process group to reduce metric results from DDP
-            reduce_op: the operation to perform for ddp reduction
         """
-        super().__init__(name='dice',
-                         reduce_group=reduce_group,
-                         reduce_op=reduce_op)
+        super().__init__(
+            name="dice",
+            reduce_group=reduce_group,
+        )

         self.include_background = include_background
         self.nan_score = nan_score
@@ -773,12 +756,14 @@ class DiceCoefficient(TensorMetric):
         Return:
             torch.Tensor: the calculated dice coefficient
         """
-        return dice_score(pred=pred,
-                          target=target,
-                          bg=self.include_background,
-                          nan_score=self.nan_score,
-                          no_fg_score=self.no_fg_score,
-                          reduction=self.reduction)
+        return dice_score(
+            pred=pred,
+            target=target,
+            bg=self.include_background,
+            nan_score=self.nan_score,
+            no_fg_score=self.no_fg_score,
+            reduction=self.reduction,
+        )


 class IoU(TensorMetric):
@@ -799,11 +784,7 @@ class IoU(TensorMetric):

     """

-    def __init__(
-            self,
-            remove_bg: bool = False,
-            reduction: str = 'elementwise_mean'
-    ):
+    def __init__(self, remove_bg: bool = False, reduction: str = "elementwise_mean"):
         """
         Args:
             remove_bg: Flag to state whether a background class has been included
@@ -817,12 +798,11 @@ class IoU(TensorMetric):
                 - none: pass array
                 - sum: add elements
         """
-        super().__init__(name='iou')
+        super().__init__(name="iou")
        self.remove_bg = remove_bg
        self.reduction = reduction

-    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor,
-                sample_weight: Optional[torch.Tensor] = None):
+    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, sample_weight: Optional[torch.Tensor] = None):
         """
         Actual metric calculation.
         """
pytorch_lightning/metrics/converters.py

@@ -18,6 +18,7 @@ conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as u
 sync tensors between different processes in a DDP scenario, when needed.
 """

+from functools import reduce
 import numbers
 from typing import Any, Callable, Optional, Union

@@ -31,10 +32,11 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
 try:
     from torch.distributed import ReduceOp
 except ImportError:
+
     class ReduceOp:
         SUM = None

-    rank_zero_warn('Unsupported `ReduceOp` for distributed computing')
+    rank_zero_warn("Unsupported `ReduceOp` for distributed computing")


 def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
@@ -138,8 +140,9 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
     Return:
         Callable: the decorated function
     """
-    return _apply_to_inputs(
-        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate)
+    return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(
+        func_to_decorate
+    )


 def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@@ -185,8 +188,9 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
     Return:
         Callable: the decorated function
     """
-    return _apply_to_inputs(
-        apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate)
+    return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
+        func_to_decorate
+    )


 def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@@ -199,8 +203,9 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
     Return:
         Callable: the decorated function
     """
-    return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
-                             convert_to_tensor)(func_to_decorate)
+    return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
+        func_to_decorate
+    )


 def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
@@ -240,10 +245,9 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable
     return _tensor_collection_metric_output_conversion(func_convert_inputs)


-def sync_ddp_if_available(result: Union[torch.Tensor],
-                          group: Optional[Any] = None,
-                          reduce_op: Optional[ReduceOp] = None
-                          ) -> torch.Tensor:
+def sync_ddp_if_available(
+    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+) -> torch.Tensor:
     """
     Function to reduce the tensors from several ddp processes to one master process

@@ -265,14 +269,13 @@ def sync_ddp_if_available(result: Union[torch.Tensor],

         if reduce_op is None:
             reduce_op = torch.distributed.ReduceOp.SUM
-        elif isinstance(reduce_op, str) and reduce_op in ('avg', 'mean'):
+        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
             reduce_op = torch.distributed.ReduceOp.SUM
             divide_by_world_size = True

         # sync all processes before reduction
         torch.distributed.barrier(group=group)
-        torch.distributed.all_reduce(result, op=reduce_op, group=group,
-                                     async_op=False)
+        torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)

         if divide_by_world_size:
             result = result / torch.distributed.get_world_size(group)
@@ -280,8 +283,21 @@ def sync_ddp_if_available(result: Union[torch.Tensor],
     return result


-def gather_all_tensors_if_available(result: Union[torch.Tensor],
-                                    group: Optional[Any] = None):
+def at_least_1d(tensor: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
+    """Makes sure the tensor is at least of 1d shape
+
+    Args:
+        tensor: the tensor or array to check the shape for
+
+    Returns:
+        the optionally reshaped tensor
+    """
+    if tensor.shape == ():
+        tensor = tensor.reshape(1, )
+    return tensor
+
+
+def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
     """
     Function to gather all tensors from several ddp processes onto a list that
     is broadcasted to all processes
@@ -312,8 +328,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor],
     return result


-def sync_ddp(group: Optional[Any] = None,
-             reduce_op: Optional[ReduceOp] = None) -> Callable:
+def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
     """
     This decorator syncs a functions outputs across different processes for DDP.

@@ -327,15 +342,14 @@ def sync_ddp(group: Optional[Any] = None,
     """

     def decorator_fn(func_to_decorate):
-        return _apply_to_outputs(apply_to_collection, torch.Tensor,
-                                 sync_ddp_if_available, group=group,
-                                 reduce_op=reduce_op)(func_to_decorate)
+        return _apply_to_outputs(
+            apply_to_collection, torch.Tensor, sync_ddp_if_available, group=group, reduce_op=reduce_op
+        )(func_to_decorate)

     return decorator_fn


-def numpy_metric(group: Optional[Any] = None,
-                 reduce_op: Optional[ReduceOp] = None) -> Callable:
+def numpy_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
     """
     This decorator shall be used on all function metrics working on numpy arrays.
     It handles the argument conversion and DDP reduction for metrics working on numpy.
@@ -357,8 +371,7 @@ def numpy_metric(group: Optional[Any] = None,
     return decorator_fn


-def tensor_metric(group: Optional[Any] = None,
-                  reduce_op: Optional[ReduceOp] = None) -> Callable:
+def tensor_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
     """
     This decorator shall be used on all function metrics working on tensors.
     It handles the argument conversion and DDP reduction for metrics working on tensors.
@@ -379,8 +392,7 @@ def tensor_metric(group: Optional[Any] = None,
     return decorator_fn


-def tensor_collection_metric(group: Optional[Any] = None,
-                             reduce_op: Optional[ReduceOp] = None) -> Callable:
+def tensor_collection_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
     """
     This decorator shall be used on all function metrics working on tensors and returning collections
     that cannot be converted to tensors.
pytorch_lightning/metrics/metric.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any, Mapping, Optional, Sequence
 import numbers

 import torch
@@ -21,8 +21,11 @@ from torch import nn
 import numpy as np

 from pytorch_lightning.metrics.converters import (
-    sync_ddp_if_available, gather_all_tensors_if_available,
-    convert_to_tensor, convert_to_numpy)
+    at_least_1d,
+    gather_all_tensors_if_available,
+    convert_to_tensor,
+    convert_to_numpy,
+)
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin

@@ -40,32 +43,41 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):

     * input_convert: pre-forward hook that takes care of input conversion
     * output_convert: post-forward hook that takes care of output convertion
-    * ddp_sync: implementation of ddp sync, default is gather all
-    * aggregate: implement how values should be aggregated
+    * ddp_reduce: implementation of ddp sync + aggregation, default is ddp_sync + aggregate
     * compute: post-ddp sync for additional metric computations

+    ``ddp_reduce`` by default calls the following methods, which can also be overwritten if necessary.
+
+    * ddp_sync: implements how values should be synced across ddp-processes. Defaults to gather all.
+    * aggregate: implement how values should be aggregated (defaults to mean).
+
     Call order

-        input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute
+        input_convert -> forward -> output_convert -> ddp_reduce (per default being ddp_sync -> aggregate) -> compute

     """

-    def __init__(self, name: str):
+    def __init__(self, name: str, reduce_group: Optional[Any] = None):
         """
         Args:
             name: the metric's name
+            reduce_group: the process group for DDP reduces (only needed for DDP training).
+                Defaults to all processes (world)
+
         """
         super().__init__()
         self.name = name
         self._dtype = torch.get_default_dtype()
-        self._device = torch.device('cpu')
+        self._device = torch.device("cpu")
+
+        self.reduce_group = reduce_group
+
+        self._step_vals = []

         # Register hooks
         self.register_forward_pre_hook(self.input_convert)
         self.register_forward_hook(self.output_convert)
-        self.register_forward_hook(self.ddp_sync)
-        self.register_forward_hook(self.aggregate)
+        self.register_forward_hook(self.ddp_reduce)
         self.register_forward_hook(self.compute)

     @staticmethod
@@ -104,12 +116,30 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
         Returns:
             casted outputs
         """
-        return output
+        return apply_to_collection(output, (torch.Tensor, np.ndarray), at_least_1d)

-    @staticmethod
-    def ddp_sync(self, data: Any, output: Any):
+    def ddp_sync(self, tensor: Any):
         """
         Implement how the outputs from forward should be synced
+        (per default just gathers all of them and adds them to self._step_vals)
+
+        Args:
+            tensor: tensor to sync
+
+        Returns:
+            synced output
+
+        """
+        gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group)
+
+        self._step_vals.append(gathered_tensors)
+
+        return gathered_tensors
+
+    @staticmethod
+    def ddp_reduce(self, data: Any, output: Any):
+        """
+        Implement how the outputs from forward should be synced and reduced across nodes

         Args:
             data: input to forward method
@@ -119,27 +149,36 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
             synced output

         """
-        return output
+        synced = self.ddp_sync(output)
+        return self.aggregate(synced)

-    @staticmethod
-    def aggregate(self, data: Any, output: Any):
+    def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
         """
         Implement aggregation of values on the same device

         Args:
-            data: input to forward method
-            output: output from the `ddp_sync` hook
+            tensors: the values to be aggregated

         Returns:
             aggregated values

         """
-        return output
+        try:
+            return torch.cat(tensors).mean(0)
+        except (ValueError, TypeError):
+            if isinstance(tensors[0], Mapping):
+                return {k: torch.stack([tensor[k] for tensor in tensors]).mean(0) for k in tensors[0].keys()}
+            elif isinstance(tensors[0], Sequence) and not isinstance(tensors[0], torch.Tensor):
+                return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)])
+            elif isinstance(tensors[0], torch.Tensor):
+                return torch.stack(tensors).mean(0)
+            else:
+                raise TypeError("unknown metric value format to aggregate")

     @staticmethod
     def compute(self, data: Any, output: Any):
         """
-        Implement additionally metric computations to be done after the ddp sync
+        Implement additionally metric computations to be done after the aggregation

         Args:
             data: input to forward method
@@ -151,6 +190,15 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
         """
         return output

+    @property
+    def aggregated(self) -> torch.Tensor:
+        aggr = self.aggregate(*self._step_vals)
+        self.reset()
+        return self.compute(self, None, aggr)
+
+    def reset(self):
+        self._step_vals = []
+

 class TensorMetric(Metric):
     """
@@ -159,91 +207,20 @@ class TensorMetric(Metric):
     Already handles DDP sync and input/output conversions.
     """

-    def __init__(self, name: str,
-                 reduce_group: Optional[Any] = None,
-                 reduce_op: Optional[Any] = None):
-        """
-
-        Args:
-            name: the metric's name
-            reduce_group: the process group for DDP reduces (only needed for DDP training).
-                Defaults to all processes (world)
-            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
-                Defaults to sum.
-        """
-        super().__init__(name)
-        self.reduce_group = reduce_group
-        self.reduce_op = reduce_op
-
     @staticmethod
     def input_convert(self, data: Any):
-        return apply_to_collection(data,
-                                   (torch.Tensor, np.ndarray, numbers.Number),
-                                   convert_to_tensor,
-                                   self.dtype, self.device)
+        data = apply_to_collection(
+            data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
+        )
+        return super(TensorMetric, self).input_convert(self, data)

     @staticmethod
     def output_convert(self, data: Any, output: Any):
-        return apply_to_collection(output, torch.Tensor, convert_to_tensor,
-                                   self.dtype, self.device)
-
-    @staticmethod
-    def ddp_sync(self, data: Any, output: Any):
-        return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
-                                   self.reduce_group, self.reduce_op)
-
-
-class TensorCollectionMetric(Metric):
-    """
-    Base class for metric implementation operating directly on tensors.
-    All inputs will be casted to tensors if necessary. Outputs won't be casted.
-    Already handles DDP sync and input conversions.
-
-    This class differs from :class:`TensorMetric`, as it assumes all outputs to
-    be collections of tensors and does not explicitly convert them. This is
-    necessary, since some collections (like for ROC, Precision-Recall Curve etc.)
-    cannot be converted to tensors at the highest level.
-    All numpy arrays and numbers occuring in these outputs will still be converted.
-
-    Use this class as a baseclass, whenever you want to ensure inputs are
-    tensors and outputs cannot be converted to tensors automatically
-
-    """
-
-    def __init__(self, name: str,
-                 reduce_group: Optional[Any] = None,
-                 reduce_op: Optional[Any] = None):
-        """
-
-        Args:
-            name: the metric's name
-            reduce_group: the process group for DDP reduces (only needed for DDP training).
-                Defaults to all processes (world)
-            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
-                Defaults to sum.
-        """
-        super().__init__(name)
-        self.reduce_group = reduce_group
-        self.reduce_op = reduce_op
-
-    @staticmethod
-    def input_convert(self, data: Any):
-        return apply_to_collection(data,
-                                   (torch.Tensor, np.ndarray, numbers.Number),
-                                   convert_to_tensor,
-                                   self.dtype, self.device)
-
-    @staticmethod
-    def output_convert(self, data: Any, output: Any):
-        return apply_to_collection(output,
-                                   (torch.Tensor, np.ndarray, numbers.Number),
-                                   convert_to_tensor,
-                                   self.dtype, self.device)
-
-    @staticmethod
-    def ddp_sync(self, data: Any, output: Any):
-        return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
-                                   self.reduce_group, self.reduce_op)
+        output = apply_to_collection(
+            output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
+        )
+        return super(TensorMetric, self).output_convert(self, data, output)


 class NumpyMetric(Metric):
@@ -254,36 +231,15 @@ class NumpyMetric(Metric):
     Already handles DDP sync and input/output conversions.
     """

-    def __init__(self, name: str,
-                 reduce_group: Optional[Any] = None,
-                 reduce_op: Optional[Any] = None):
-        """
-
-        Args:
-            name: the metric's name
-            reduce_group: the process group for DDP reduces (only needed for DDP training).
-                Defaults to all processes (world)
-            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
-                Defaults to sum.
-        """
-        super().__init__(name)
-        self.reduce_group = reduce_group
-        self.reduce_op = reduce_op
-
     @staticmethod
     def input_convert(self, data: Any):
-        return apply_to_collection(data,
-                                   (torch.Tensor, np.ndarray, numbers.Number),
-                                   convert_to_numpy)
+        data = apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
+        return super(NumpyMetric, self).input_convert(self, data)

     @staticmethod
     def output_convert(self, data: Any, output: Any):
-        return apply_to_collection(output,
-                                   (torch.Tensor, np.ndarray, numbers.Number),
-                                   convert_to_tensor,
-                                   self.dtype, self.device)
-
-    @staticmethod
-    def ddp_sync(self, data: Any, output: Any):
-        return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
-                                   self.reduce_group, self.reduce_op)
+        output = apply_to_collection(
+            output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
+        )
+
+        return super(NumpyMetric, self).output_convert(self, data, output)
(File diff suppressed because it is too large.)
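The Metric base class above documents the new hook order (input_convert -> forward -> output_convert -> ddp_reduce -> compute, with ddp_reduce calling ddp_sync and aggregate). As a hedged illustration only — this class is not part of the commit — a custom metric could override `aggregate` the same way ConfusionMatrix does in the classification diff, stacking fixed-size step results instead of concatenating them:

    import torch
    from pytorch_lightning.metrics.metric import TensorMetric


    class StackedMean(TensorMetric):
        """Toy sketch: forward returns a fixed-size vector per step, so step
        results are stacked (not concatenated) before averaging, mirroring
        ConfusionMatrix.aggregate above."""

        def __init__(self):
            super().__init__(name="stacked_mean")

        def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            acc = (pred == target).float().mean()
            return torch.stack([acc, 1 - acc])  # fixed-size output per step

        def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
            # stack step results along a new leading dim, then average over steps
            return torch.stack(tensors).mean(0)

Scalar or 1d step values need no override; the default `aggregate` concatenates and averages them.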
tests/metrics/test_metrics.py

@@ -1,48 +1,49 @@
 import os
+from typing import Any

 import numpy as np
 import pytest
 import torch

 import tests.base.develop_utils as tutils
 from tests.base import EvalModelTemplate
-from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric
+from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
 from pytorch_lightning import Trainer


 class DummyTensorMetric(TensorMetric):
     def __init__(self):
-        super().__init__('dummy')
+        super().__init__("dummy")

     def forward(self, input1, input2):
         assert isinstance(input1, torch.Tensor)
         assert isinstance(input2, torch.Tensor)
-        return torch.tensor([1.])
+        return torch.tensor([1.0])


 class DummyNumpyMetric(NumpyMetric):
     def __init__(self):
-        super().__init__('dummy')
+        super().__init__("dummy")

     def forward(self, input1, input2):
         assert isinstance(input1, np.ndarray)
         assert isinstance(input2, np.ndarray)
-        return 1.
+        return 1.0


-class DummyTensorCollectionMetric(TensorCollectionMetric):
+class DummyTensorCollectionMetric(TensorMetric):
     def __init__(self):
-        super().__init__('dummy')
+        super().__init__("dummy")

     def forward(self, input1, input2):
         assert isinstance(input1, torch.Tensor)
         assert isinstance(input2, torch.Tensor)
-        return 1., 2., 3., 4.
+        return 1.0, 2.0, 3.0, 4.0


-@pytest.mark.parametrize('metric', [DummyTensorCollectionMetric()])
+@pytest.mark.parametrize("metric", [DummyTensorCollectionMetric()])
 def test_collection_metric(metric: Metric):
     """ Test that metric.device, metric.dtype works for metric collection """
-    input1, input2 = torch.tensor([1.]), torch.tensor([2.])
+    input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])

     def change_and_check_device_dtype(device, dtype):
         metric.to(device=device, dtype=dtype)
@@ -56,9 +57,9 @@ def test_collection_metric(metric: Metric):
         if dtype is not None:
             assert metric.dtype == dtype

-    devices = [None, 'cpu']
+    devices = [None, "cpu"]
     if torch.cuda.is_available():
-        devices += ['cuda:0']
+        devices += ["cuda:0"]

     for device in devices:
         for dtype in [None, torch.float32, torch.float64]:
@@ -66,10 +67,10 @@ def test_collection_metric(metric: Metric):

     if torch.cuda.is_available():
         metric.cuda(0)
-        assert metric.device == torch.device('cuda', index=0)
+        assert metric.device == torch.device("cuda", index=0)

     metric.cpu()
-    assert metric.device == torch.device('cpu')
+    assert metric.device == torch.device("cpu")

     metric.type(torch.int8)
     assert metric.dtype == torch.int8
@@ -87,13 +88,16 @@ def test_collection_metric(metric: Metric):
     assert metric.dtype == torch.float16


-@pytest.mark.parametrize('metric', [
-    DummyTensorMetric(),
-    DummyNumpyMetric(),
-])
+@pytest.mark.parametrize(
+    "metric",
+    [
+        DummyTensorMetric(),
+        DummyNumpyMetric(),
+    ],
+)
 def test_metric(metric: Metric):
     """ Test that metric.device, metric.dtype works for single metric"""
-    input1, input2 = torch.tensor([1.]), torch.tensor([2.])
+    input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])

     def change_and_check_device_dtype(device, dtype):
         metric.to(device=device, dtype=dtype)
@@ -109,9 +113,9 @@ def test_metric(metric: Metric):
         assert metric.dtype == dtype
         assert metric_val.dtype == dtype

-    devices = [None, 'cpu']
+    devices = [None, "cpu"]
     if torch.cuda.is_available():
-        devices += ['cuda:0']
+        devices += ["cuda:0"]

     for device in devices:
         for dtype in [None, torch.float32, torch.float64]:
@@ -119,16 +123,12 @@ def test_metric(metric: Metric):

         if torch.cuda.is_available():
             metric.cuda(0)
-            assert metric.device == torch.device('cuda', index=0)
-            assert metric(input1, input2).device == torch.device('cuda', index=0)
+            assert metric.device == torch.device("cuda", index=0)
+            assert metric(input1, input2).device == torch.device("cuda", index=0)

         metric.cpu()
-        assert metric.device == torch.device('cpu')
-        assert metric(input1, input2).device == torch.device('cpu')
-
-        metric.type(torch.int8)
-        assert metric.dtype == torch.int8
-        assert metric(input1, input2).dtype == torch.int8
+        assert metric.device == torch.device("cpu")
+        assert metric(input1, input2).device == torch.device("cpu")

         metric.float()
         assert metric.dtype == torch.float32
@@ -156,7 +156,7 @@ def test_model_pickable(tmpdir, metric: Metric):
         max_epochs=1,
         limit_train_batches=10,
         gpus=[0, 1],
-        distributed_backend='ddp_spawn',
+        distributed_backend="ddp_spawn",
     )

     model = EvalModelTemplate()
@@ -167,17 +167,19 @@ def test_model_pickable(tmpdir, metric: Metric):
     result = trainer.fit(model)

     # correct result and ok accuracy
-    assert result == 1, 'ddp model failed to complete'
+    assert result == 1, "ddp model failed to complete"


 @pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()])
 def test_saving_pickable(tmpdir, metric: Metric):
     """ Make sure that metrics are pickable by saving and loading them using torch """
-    x, y = torch.randn(10,), torch.randn(10,)
+    x, y = torch.randn(10,), torch.randn(
+        10,
+    )
     results_before_save = metric(x, y)

     # save metric
-    save_path = os.path.join(tmpdir, 'save_test.ckpt')
+    save_path = os.path.join(tmpdir, "save_test.ckpt")
     torch.save(metric, save_path)

     # load metric
@@ -186,3 +188,125 @@ def test_saving_pickable(tmpdir, metric: Metric):

     # Check metric value is the same
     assert results_before_save == results_after_load
+
+
+def check_call_order():
+    class DummyMetric(Metric):
+        def __init__(self):
+            super().__init__("dummy")
+            self.call_history = ["init"]
+
+        @staticmethod
+        def input_convert(self, data: Any):
+            self.call_history.append("input_convert")
+            return super(DummyMetric, self).input_convert(self, data)
+
+        def forward(self, tensor1, tensor2):
+            self.call_history.append("forward")
+            return tensor1 - tensor2
+
+        @staticmethod
+        def output_convert(self, data: Any, output: Any):
+            self.call_history.append("output_convert")
+            return super(DummyMetric, self).output_convert(self, data, output)
+
+        def ddp_sync(self, tensor: Any):
+            self.call_history.append("ddp_sync")
+            return super().ddp_sync(tensor)
+
+        @staticmethod
+        def ddp_reduce(self, data: Any, output: Any):
+            self.call_history.append("ddp_reduce")
+            return super(DummyMetric, self).ddp_reduce(self, data, output)
+
+        def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
+            self.call_history.append("aggregate")
+            return super().aggregate(*tensors)
+
+        def reset(self):
+            self.call_history.append("reset")
+            return super().reset()
+
+        @property
+        def aggregated(self) -> torch.Tensor:
+            self.call_history.append("aggregated")
+            return super().aggregated
+
+        @staticmethod
+        def compute(self, data: Any, output: Any):
+            self.call_history.append("compute")
+            return super(DummyMetric, self).compute(self, data, output)
+
+    metric = DummyMetric()
+    assert metric.call_history == ["init"]
+    result = metric(torch.tensor([2.0]), torch.tensor([1.0]))
+    assert torch.allclose(result, torch.tensor(1.0))
+    assert metric.call_history == [
+        "init",
+        "input_convert",
+        "forward",
+        "output_convert",
+        "ddp_reduce",
+        "ddp_sync",
+        "aggregate",
+    ]
+    aggr = metric.aggregated
+    assert metric.call_history == [
+        "init",
+        "input_convert",
+        "forward",
+        "output_convert",
+        "ddp_reduce",
+        "ddp_sync",
+        "aggregate",
+        "aggregated",
+        "aggregate",
+        "reset",
+    ]
+    assert torch.allclose(aggr, result)
+    _ = metric(torch.tensor(2.0), torch.tensor(1.0))
+    assert metric.call_history == [
+        "init",
+        "input_convert",
+        "forward",
+        "output_convert",
+        "ddp_reduce",
+        "ddp_sync",
+        "aggregate",
+        "aggregated",
+        "aggregate",
+        "reset",
+        "input_convert",
+        "forward",
+        "output_convert",
+        "ddp_reduce",
+        "ddp_sync",
+        "aggregate",
+    ]
+
+    metric = DummyMetric()
+    _ = metric(torch.tensor([2.0]), torch.tensor([1.0]))
+    _ = metric(torch.tensor([3.0]), torch.tensor([0.0]))
+
+    aggregated = metric.aggregated
+
+    assert torch.allclose(aggregated, torch.tensor(2.0))
+
+    assert metric.call_history == [
+        "init",
+        "input_convert",
+        "forward",
+        "output_convert",
+        "ddp_reduce",
+        "ddp_sync",
+        "aggregate",
+        "input_convert",
+        "forward",
+        "output_convert",
+        "ddp_reduce",
+        "ddp_sync",
+        "aggregate",
+        "aggregated",
+        "aggregate",
+        "reset",
+    ]
tests/metrics/test_sklearn_metrics.py

@@ -167,13 +167,12 @@ def test_sklearn_metric(metric_class, sklearn_func, inputs):

     sklearn_result = sklearn_func(**numpy_inputs)
     lightning_result = metric_class(**inputs)
-    assert np.allclose(sklearn_result, lightning_result, atol=1e-5)

     sklearn_result = apply_to_collection(
         sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)

-    lightning_result = apply_to_collection(
-        lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
+    lightning_result = np.array(apply_to_collection(
+        lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy))

     assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
     assert isinstance(lightning_result, type(sklearn_result))