Metric aggregation (#3321)

* metric aggregation

* metric aggregation

* add at_least_1d

* fix output formatting

* add metric tests

* add missing test case

* remove reduce_op from metric classes

* fix reduce_op stuff

* start test fixing

* fix tests due to aggregation

* fix faulty import

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* remove reduce_op docstrings

* add compute

* remove import

* remove collection metric

* update base class

* update tests

* Update metric.py

* Update metric.py

* Apply suggestions from code review

* change default aggregate

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
Justus Schock 2020-09-14 13:23:11 +02:00 committed by GitHub
parent 50b8388f03
commit 4dc4c8cfa5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 735 additions and 827 deletions

View File

@ -30,9 +30,9 @@ from pytorch_lightning.metrics.functional.classification import (
precision,
precision_recall_curve,
recall,
roc
roc,
)
from pytorch_lightning.metrics.metric import TensorCollectionMetric, TensorMetric
from pytorch_lightning.metrics.metric import TensorMetric
class Accuracy(TensorMetric):
@ -44,17 +44,16 @@ class Accuracy(TensorMetric):
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Accuracy()
>>> metric(pred, target)
tensor(0.7500)
>>> metric(pred, target).item()
0.75
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
self,
num_classes: Optional[int] = None,
reduction: str = "elementwise_mean",
reduce_group: Any = None,
):
"""
Args:
@ -65,11 +64,8 @@ class Accuracy(TensorMetric):
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='accuracy',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(name="accuracy", reduce_group=reduce_group)
self.num_classes = num_classes
self.reduction = reduction
@ -84,8 +80,7 @@ class Accuracy(TensorMetric):
Return:
A Tensor with the classification score.
"""
return accuracy(pred=pred, target=target,
num_classes=self.num_classes, reduction=self.reduction)
return accuracy(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)
class ConfusionMatrix(TensorMetric):
@ -106,22 +101,21 @@ class ConfusionMatrix(TensorMetric):
"""
def __init__(
self,
num_classes: Optional[int] = None,
normalize: bool = False,
reduce_group: Any = None,
reduce_op: Any = None,
self,
num_classes: Optional[int] = None,
normalize: bool = False,
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
normalize: whether to compute a normalized confusion matrix
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='confusion_matrix',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="confusion_matrix",
reduce_group=reduce_group,
)
self.normalize = normalize
self.num_classes = num_classes
@ -140,8 +134,16 @@ class ConfusionMatrix(TensorMetric):
normalize=self.normalize,
num_classes=self.num_classes)
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
"""Aggregates results by stacking them instead of concatenating before averaging.
class PrecisionRecallCurve(TensorCollectionMetric):
Returns:
the aggregated results
"""
return torch.stack(tensors).mean(0)
class PrecisionRecallCurve(TensorMetric):
"""
Computes the precision recall curve
@ -161,28 +163,27 @@ class PrecisionRecallCurve(TensorCollectionMetric):
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='precision_recall_curve',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="precision_recall_curve",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
@ -197,9 +198,7 @@ class PrecisionRecallCurve(TensorCollectionMetric):
- recall values
- threshold values
"""
return precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
return precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
class Precision(TensorMetric):
@ -217,11 +216,10 @@ class Precision(TensorMetric):
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
self,
num_classes: Optional[int] = None,
reduction: str = "elementwise_mean",
reduce_group: Any = None,
):
"""
Args:
@ -232,11 +230,11 @@ class Precision(TensorMetric):
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='precision',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="precision",
reduce_group=reduce_group,
)
self.num_classes = num_classes
self.reduction = reduction
@ -251,9 +249,7 @@ class Precision(TensorMetric):
Return:
A Tensor with the classification score.
"""
return precision(pred=pred, target=target,
num_classes=self.num_classes,
reduction=self.reduction)
return precision(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)
class Recall(TensorMetric):
@ -271,11 +267,10 @@ class Recall(TensorMetric):
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
self,
num_classes: Optional[int] = None,
reduction: str = "elementwise_mean",
reduce_group: Any = None,
):
"""
Args:
@ -286,11 +281,11 @@ class Recall(TensorMetric):
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='recall',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="recall",
reduce_group=reduce_group,
)
self.num_classes = num_classes
self.reduction = reduction
@ -306,10 +301,7 @@ class Recall(TensorMetric):
Return:
A Tensor with the classification score.
"""
return recall(pred=pred,
target=target,
num_classes=self.num_classes,
reduction=self.reduction)
return recall(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)
class AveragePrecision(TensorMetric):
@ -327,28 +319,24 @@ class AveragePrecision(TensorMetric):
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='AP',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="AP",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
) -> torch.Tensor:
"""
Actual metric computation
@ -361,9 +349,7 @@ class AveragePrecision(TensorMetric):
Return:
torch.Tensor: classification score
"""
return average_precision(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
return average_precision(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
class AUROC(TensorMetric):
@ -381,28 +367,24 @@ class AUROC(TensorMetric):
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='auroc',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="auroc",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
) -> torch.Tensor:
"""
Actual metric computation
@ -415,9 +397,7 @@ class AUROC(TensorMetric):
Return:
torch.Tensor: classification score
"""
return auroc(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
return auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
class FBeta(TensorMetric):
@ -435,12 +415,11 @@ class FBeta(TensorMetric):
"""
def __init__(
self,
beta: float,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
self,
beta: float,
num_classes: Optional[int] = None,
reduction: str = "elementwise_mean",
reduce_group: Any = None,
):
"""
Args:
@ -452,11 +431,11 @@ class FBeta(TensorMetric):
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for DDP reduction
"""
super().__init__(name='fbeta',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="fbeta",
reduce_group=reduce_group,
)
self.beta = beta
self.num_classes = num_classes
@ -473,9 +452,9 @@ class FBeta(TensorMetric):
Return:
torch.Tensor: classification score
"""
return fbeta_score(pred=pred, target=target,
beta=self.beta, num_classes=self.num_classes,
reduction=self.reduction)
return fbeta_score(
pred=pred, target=target, beta=self.beta, num_classes=self.num_classes, reduction=self.reduction
)
class F1(TensorMetric):
@ -493,11 +472,10 @@ class F1(TensorMetric):
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
self,
num_classes: Optional[int] = None,
reduction: str = "elementwise_mean",
reduce_group: Any = None,
):
"""
Args:
@ -508,11 +486,11 @@ class F1(TensorMetric):
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='f1',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="f1",
reduce_group=reduce_group,
)
self.num_classes = num_classes
self.reduction = reduction
@ -528,12 +506,10 @@ class F1(TensorMetric):
Return:
torch.Tensor: classification score
"""
return f1_score(pred=pred, target=target,
num_classes=self.num_classes,
reduction=self.reduction)
return f1_score(pred=pred, target=target, num_classes=self.num_classes, reduction=self.reduction)
class ROC(TensorCollectionMetric):
class ROC(TensorMetric):
"""
Computes the Receiver Operator Characteristic (ROC)
@ -553,28 +529,24 @@ class ROC(TensorCollectionMetric):
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
reduce_op: Any = None,
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='roc',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="roc",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
@ -589,12 +561,10 @@ class ROC(TensorCollectionMetric):
- true positive rate
- thresholds
"""
return roc(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=self.pos_label)
return roc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
class MulticlassROC(TensorCollectionMetric):
class MulticlassROC(TensorMetric):
"""
Computes the multiclass ROC
@ -615,27 +585,27 @@ class MulticlassROC(TensorCollectionMetric):
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
reduce_op: Any = None,
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='multiclass_roc',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="multiclass_roc",
reduce_group=reduce_group,
)
self.num_classes = num_classes
def forward(
self, pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Actual metric computation
@ -649,13 +619,19 @@ class MulticlassROC(TensorCollectionMetric):
tuple: A tuple consisting of one tuple per class, holding false positive rate, true positive rate and thresholds
"""
return multiclass_roc(pred=pred,
target=target,
sample_weight=sample_weight,
num_classes=self.num_classes)
return multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes)
def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Aggregates results by stacking them instead of concatenating before averaging.
Returns:
the aggregated results
"""
return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])
class MulticlassPrecisionRecallCurve(TensorCollectionMetric):
class MulticlassPrecisionRecallCurve(TensorMetric):
"""Computes the multiclass PR Curve
Example:
@ -674,29 +650,28 @@ class MulticlassPrecisionRecallCurve(TensorCollectionMetric):
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
reduce_op: Any = None,
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='multiclass_precision_recall_curve',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="multiclass_precision_recall_curve",
reduce_group=reduce_group,
)
self.num_classes = num_classes
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
@ -710,10 +685,18 @@ class MulticlassPrecisionRecallCurve(TensorCollectionMetric):
tuple: A tuple consisting of one tuple per class, holding precision, recall and thresholds
"""
return multiclass_precision_recall_curve(pred=pred,
target=target,
sample_weight=sample_weight,
num_classes=self.num_classes)
return multiclass_precision_recall_curve(
pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes
)
def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Aggregates results by stacking them instead of concatenating before averaging.
Returns:
the aggregated results
"""
return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])
class DiceCoefficient(TensorMetric):
@ -733,12 +716,12 @@ class DiceCoefficient(TensorMetric):
"""
def __init__(
self,
include_background: bool = False,
nan_score: float = 0.0, no_fg_score: float = 0.0,
reduction: str = 'elementwise_mean',
reduce_group: Any = None,
reduce_op: Any = None,
self,
include_background: bool = False,
nan_score: float = 0.0,
no_fg_score: float = 0.0,
reduction: str = "elementwise_mean",
reduce_group: Any = None,
):
"""
Args:
@ -751,11 +734,11 @@ class DiceCoefficient(TensorMetric):
- none: pass array
- sum: add elements
reduce_group: the process group to reduce metric results from DDP
reduce_op: the operation to perform for ddp reduction
"""
super().__init__(name='dice',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(
name="dice",
reduce_group=reduce_group,
)
self.include_background = include_background
self.nan_score = nan_score
@ -773,12 +756,14 @@ class DiceCoefficient(TensorMetric):
Return:
torch.Tensor: the calculated dice coefficient
"""
return dice_score(pred=pred,
target=target,
bg=self.include_background,
nan_score=self.nan_score,
no_fg_score=self.no_fg_score,
reduction=self.reduction)
return dice_score(
pred=pred,
target=target,
bg=self.include_background,
nan_score=self.nan_score,
no_fg_score=self.no_fg_score,
reduction=self.reduction,
)
class IoU(TensorMetric):
@ -799,11 +784,7 @@ class IoU(TensorMetric):
"""
def __init__(
self,
remove_bg: bool = False,
reduction: str = 'elementwise_mean'
):
def __init__(self, remove_bg: bool = False, reduction: str = "elementwise_mean"):
"""
Args:
remove_bg: Flag to state whether a background class has been included
@ -817,12 +798,11 @@ class IoU(TensorMetric):
- none: pass array
- sum: add elements
"""
super().__init__(name='iou')
super().__init__(name="iou")
self.remove_bg = remove_bg
self.reduction = reduction
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor,
sample_weight: Optional[torch.Tensor] = None):
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, sample_weight: Optional[torch.Tensor] = None):
"""
Actual metric calculation.
"""

View File

@ -18,6 +18,7 @@ conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as u
sync tensors between different processes in a DDP scenario, when needed.
"""
from functools import reduce
import numbers
from typing import Any, Callable, Optional, Union
@ -31,10 +32,11 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
try:
from torch.distributed import ReduceOp
except ImportError:
class ReduceOp:
SUM = None
rank_zero_warn('Unsupported `ReduceOp` for distributed computing')
rank_zero_warn("Unsupported `ReduceOp` for distributed computing")
def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
@ -138,8 +140,9 @@ def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(func_to_decorate)
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(
func_to_decorate
)
def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@ -185,8 +188,9 @@ def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
Return:
Callable: the decorated function
"""
return _apply_to_inputs(
apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(func_to_decorate)
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
func_to_decorate
)
def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
@ -199,8 +203,9 @@ def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> C
Return:
Callable: the decorated function
"""
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor)(func_to_decorate)
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
func_to_decorate
)
def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
@ -240,10 +245,9 @@ def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable
return _tensor_collection_metric_output_conversion(func_convert_inputs)
def sync_ddp_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None
) -> torch.Tensor:
def sync_ddp_if_available(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process
@ -265,14 +269,13 @@ def sync_ddp_if_available(result: Union[torch.Tensor],
if reduce_op is None:
reduce_op = torch.distributed.ReduceOp.SUM
elif isinstance(reduce_op, str) and reduce_op in ('avg', 'mean'):
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
reduce_op = torch.distributed.ReduceOp.SUM
divide_by_world_size = True
# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group,
async_op=False)
torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)
@ -280,8 +283,21 @@ def sync_ddp_if_available(result: Union[torch.Tensor],
return result
def gather_all_tensors_if_available(result: Union[torch.Tensor],
group: Optional[Any] = None):
def at_least_1d(tensor: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    """Make sure the tensor is at least of 1d shape.

    Args:
        tensor: the tensor or array to check the shape for

    Returns:
        the input reshaped to shape ``(1,)`` if its shape was ``()``,
        otherwise the input itself, unchanged
    """
    # An empty shape tuple marks a 0-d (scalar) value; give it a single axis.
    if tensor.shape == ():
        tensor = tensor.reshape(1)
    return tensor
def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes
@ -312,8 +328,7 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor],
return result
def sync_ddp(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator syncs a functions outputs across different processes for DDP.
@ -327,15 +342,14 @@ def sync_ddp(group: Optional[Any] = None,
"""
def decorator_fn(func_to_decorate):
return _apply_to_outputs(apply_to_collection, torch.Tensor,
sync_ddp_if_available, group=group,
reduce_op=reduce_op)(func_to_decorate)
return _apply_to_outputs(
apply_to_collection, torch.Tensor, sync_ddp_if_available, group=group, reduce_op=reduce_op
)(func_to_decorate)
return decorator_fn
def numpy_metric(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def numpy_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on numpy arrays.
It handles the argument conversion and DDP reduction for metrics working on numpy.
@ -357,8 +371,7 @@ def numpy_metric(group: Optional[Any] = None,
return decorator_fn
def tensor_metric(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def tensor_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors.
It handles the argument conversion and DDP reduction for metrics working on tensors.
@ -379,8 +392,7 @@ def tensor_metric(group: Optional[Any] = None,
return decorator_fn
def tensor_collection_metric(group: Optional[Any] = None,
reduce_op: Optional[ReduceOp] = None) -> Callable:
def tensor_collection_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors and returning collections
that cannot be converted to tensors.

View File

@ -13,7 +13,7 @@
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any, Mapping, Optional, Sequence
import numbers
import torch
@ -21,8 +21,11 @@ from torch import nn
import numpy as np
from pytorch_lightning.metrics.converters import (
sync_ddp_if_available, gather_all_tensors_if_available,
convert_to_tensor, convert_to_numpy)
at_least_1d,
gather_all_tensors_if_available,
convert_to_tensor,
convert_to_numpy,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
@ -40,32 +43,41 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
* input_convert: pre-forward hook that takes care of input conversion
* output_convert: post-forward hook that takes care of output convertion
* ddp_sync: implementation of ddp sync, default is gather all
* aggregate: implement how values should be aggregated
* ddp_reduce: implementation of ddp sync + aggregation, default is ddp_sync + aggregate
* compute: post-ddp sync for additional metric computations
``ddp_reduce`` by default calls the following methods, which can also be overwritten if necessary.
* ddp_sync: implements how values should be synced across ddp-processes. Defaults to gather all.
* aggregate: implement how values should be aggregated (defaults to mean).
Call order
input_convert -> forward -> output_convert -> ddp_sync -> aggregate -> compute
input_convert -> forward -> output_convert -> ddp_reduce (per default being ddp_sync -> aggregate) -> compute
"""
def __init__(self, name: str):
def __init__(self, name: str, reduce_group: Optional[Any] = None):
"""
Args:
name: the metric's name
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
"""
super().__init__()
self.name = name
self._dtype = torch.get_default_dtype()
self._device = torch.device('cpu')
self._device = torch.device("cpu")
self.reduce_group = reduce_group
self._step_vals = []
# Register hooks
self.register_forward_pre_hook(self.input_convert)
self.register_forward_hook(self.output_convert)
self.register_forward_hook(self.ddp_sync)
self.register_forward_hook(self.aggregate)
self.register_forward_hook(self.ddp_reduce)
self.register_forward_hook(self.compute)
@staticmethod
@ -104,12 +116,30 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
Returns:
casted outputs
"""
return output
return apply_to_collection(output, (torch.Tensor, np.ndarray), at_least_1d)
@staticmethod
def ddp_sync(self, data: Any, output: Any):
def ddp_sync(self, tensor: Any):
"""
Implement how the outputs from forward should be synced
(per default just gathers all of them and adds them to self._step_vals)
Args:
tensor: tensor to sync
Returns:
synced output
"""
gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group)
self._step_vals.append(gathered_tensors)
return gathered_tensors
@staticmethod
def ddp_reduce(self, data: Any, output: Any):
"""
Implement how the outputs from forward should be synced and reduced across nodes
Args:
data: input to forward method
@ -119,27 +149,36 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
synced output
"""
return output
synced = self.ddp_sync(output)
return self.aggregate(synced)
@staticmethod
def aggregate(self, data: Any, output: Any):
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
"""
Implement aggregation of values on the same device
Args:
data: input to forward method
output: output from the `ddp_sync` hook
tensors: the values to be aggregated
Returns:
aggregated values
"""
return output
try:
return torch.cat(tensors).mean(0)
except (ValueError, TypeError):
if isinstance(tensors[0], Mapping):
return {k: torch.stack([tensor[k] for tensor in tensors]).mean(0) for k in tensors[0].keys()}
elif isinstance(tensors[0], Sequence) and not isinstance(tensors[0], torch.Tensor):
return tuple([torch.stack(tmp).mean(0) for tmp in zip(*tensors)])
elif isinstance(tensors[0], torch.Tensor):
return torch.stack(tensors).mean(0)
else:
raise TypeError("unknown metric value format to aggregate")
@staticmethod
def compute(self, data: Any, output: Any):
"""
Implement additionally metric computations to be done after the ddp sync
Implement additionally metric computations to be done after the aggregation
Args:
data: input to forward method
@ -151,6 +190,15 @@ class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
"""
return output
@property
def aggregated(self) -> torch.Tensor:
aggr = self.aggregate(*self._step_vals)
self.reset()
return self.compute(self, None, aggr)
def reset(self):
self._step_vals = []
class TensorMetric(Metric):
"""
@ -159,91 +207,20 @@ class TensorMetric(Metric):
Already handles DDP sync and input/output conversions.
"""
def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
"""
Args:
name: the metric's name
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
"""
super().__init__(name)
self.reduce_group = reduce_group
self.reduce_op = reduce_op
@staticmethod
def input_convert(self, data: Any):
return apply_to_collection(data,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor,
self.dtype, self.device)
data = apply_to_collection(
data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
)
return super(TensorMetric, self).input_convert(self, data)
@staticmethod
def output_convert(self, data: Any, output: Any):
return apply_to_collection(output, torch.Tensor, convert_to_tensor,
self.dtype, self.device)
@staticmethod
def ddp_sync(self, data: Any, output: Any):
return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
self.reduce_group, self.reduce_op)
class TensorCollectionMetric(Metric):
    """
    Base class for metrics that work directly on tensors and return collections.

    Inputs are cast to tensors where necessary. The output as a whole is not
    cast: some results (ROC, precision-recall curve, ...) are collections that
    cannot be converted to a single tensor at the highest level, which is what
    distinguishes this class from :class:`TensorMetric`. Tensors, numpy arrays
    and numbers occurring inside the output are still converted, and DDP sync
    is handled as well.

    Use this class as a base class whenever inputs must be tensors but the
    outputs cannot be converted to tensors automatically.
    """

    def __init__(
        self,
        name: str,
        reduce_group: Optional[Any] = None,
        reduce_op: Optional[Any] = None,
    ):
        """
        Args:
            name: the metric's name
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
        """
        super().__init__(name)
        self.reduce_group = reduce_group
        self.reduce_op = reduce_op

    # The three converters below are static methods that receive the metric
    # explicitly as ``self``.
    # NOTE(review): this looks like the module-hook calling convention --
    # confirm against the ``Metric`` base class.

    @staticmethod
    def input_convert(self, data: Any):
        """Convert tensors, numpy arrays and numbers in ``data`` to tensors."""
        convertible = (torch.Tensor, np.ndarray, numbers.Number)
        return apply_to_collection(
            data, convertible, convert_to_tensor, self.dtype, self.device
        )

    @staticmethod
    def output_convert(self, data: Any, output: Any):
        """Convert tensors, numpy arrays and numbers in ``output`` to tensors.

        ``data`` (the metric's inputs) is accepted but not used.
        """
        convertible = (torch.Tensor, np.ndarray, numbers.Number)
        return apply_to_collection(
            output, convertible, convert_to_tensor, self.dtype, self.device
        )

    @staticmethod
    def ddp_sync(self, data: Any, output: Any):
        """Pass every tensor in ``output`` through ``sync_ddp_if_available``.

        ``data`` (the metric's inputs) is accepted but not used.
        """
        group, op = self.reduce_group, self.reduce_op
        return apply_to_collection(
            output, torch.Tensor, sync_ddp_if_available, group, op
        )
output = apply_to_collection(
output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
)
return super(TensorMetric, self).output_convert(self, data, output)
class NumpyMetric(Metric):
@ -254,36 +231,15 @@ class NumpyMetric(Metric):
Already handles DDP sync and input/output conversions.
"""
def __init__(self, name: str,
reduce_group: Optional[Any] = None,
reduce_op: Optional[Any] = None):
"""
Args:
name: the metric's name
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
"""
super().__init__(name)
self.reduce_group = reduce_group
self.reduce_op = reduce_op
@staticmethod
def input_convert(self, data: Any):
return apply_to_collection(data,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_numpy)
data = apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
return super(NumpyMetric, self).input_convert(self, data)
@staticmethod
def output_convert(self, data: Any, output: Any):
return apply_to_collection(output,
(torch.Tensor, np.ndarray, numbers.Number),
convert_to_tensor,
self.dtype, self.device)
output = apply_to_collection(
output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
)
@staticmethod
def ddp_sync(self, data: Any, output: Any):
return apply_to_collection(output, torch.Tensor, sync_ddp_if_available,
self.reduce_group, self.reduce_op)
return super(NumpyMetric, self).output_convert(self, data, output)

File diff suppressed because it is too large Load Diff

View File

@ -1,48 +1,49 @@
import os
from typing import Any
import numpy as np
import pytest
import torch
import tests.base.develop_utils as tutils
from tests.base import EvalModelTemplate
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric, TensorCollectionMetric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning import Trainer
class DummyTensorMetric(TensorMetric):
    """Minimal ``TensorMetric`` named ``"dummy"`` for the tests in this module.

    ``forward`` asserts that both inputs reach it as ``torch.Tensor`` and
    returns a constant one-element tensor.
    """

    def __init__(self):
        super().__init__("dummy")

    def forward(self, input1, input2):
        """Check the input types and return ``tensor([1.0])``."""
        assert isinstance(input1, torch.Tensor)
        assert isinstance(input2, torch.Tensor)
        return torch.tensor([1.0])
class DummyNumpyMetric(NumpyMetric):
    """Minimal ``NumpyMetric`` named ``"dummy"`` for the tests in this module.

    ``forward`` asserts that both inputs reach it as ``np.ndarray`` and
    returns the constant Python float ``1.0``.
    """

    def __init__(self):
        super().__init__("dummy")

    def forward(self, input1, input2):
        """Check the input types and return ``1.0``."""
        assert isinstance(input1, np.ndarray)
        assert isinstance(input2, np.ndarray)
        return 1.0
class DummyTensorCollectionMetric(TensorMetric):
    """Minimal metric named ``"dummy"`` whose result is a collection.

    ``forward`` asserts that both inputs reach it as ``torch.Tensor`` and
    returns a tuple of four Python floats rather than a single tensor.
    """

    def __init__(self):
        super().__init__("dummy")

    def forward(self, input1, input2):
        """Check the input types and return ``(1.0, 2.0, 3.0, 4.0)``."""
        assert isinstance(input1, torch.Tensor)
        assert isinstance(input2, torch.Tensor)
        return 1.0, 2.0, 3.0, 4.0
@pytest.mark.parametrize('metric', [DummyTensorCollectionMetric()])
@pytest.mark.parametrize("metric", [DummyTensorCollectionMetric()])
def test_collection_metric(metric: Metric):
""" Test that metric.device, metric.dtype works for metric collection """
input1, input2 = torch.tensor([1.]), torch.tensor([2.])
input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])
def change_and_check_device_dtype(device, dtype):
metric.to(device=device, dtype=dtype)
@ -56,9 +57,9 @@ def test_collection_metric(metric: Metric):
if dtype is not None:
assert metric.dtype == dtype
devices = [None, 'cpu']
devices = [None, "cpu"]
if torch.cuda.is_available():
devices += ['cuda:0']
devices += ["cuda:0"]
for device in devices:
for dtype in [None, torch.float32, torch.float64]:
@ -66,10 +67,10 @@ def test_collection_metric(metric: Metric):
if torch.cuda.is_available():
metric.cuda(0)
assert metric.device == torch.device('cuda', index=0)
assert metric.device == torch.device("cuda", index=0)
metric.cpu()
assert metric.device == torch.device('cpu')
assert metric.device == torch.device("cpu")
metric.type(torch.int8)
assert metric.dtype == torch.int8
@ -87,13 +88,16 @@ def test_collection_metric(metric: Metric):
assert metric.dtype == torch.float16
@pytest.mark.parametrize('metric', [
DummyTensorMetric(),
DummyNumpyMetric(),
])
@pytest.mark.parametrize(
"metric",
[
DummyTensorMetric(),
DummyNumpyMetric(),
],
)
def test_metric(metric: Metric):
""" Test that metric.device, metric.dtype works for single metric"""
input1, input2 = torch.tensor([1.]), torch.tensor([2.])
input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])
def change_and_check_device_dtype(device, dtype):
metric.to(device=device, dtype=dtype)
@ -109,9 +113,9 @@ def test_metric(metric: Metric):
assert metric.dtype == dtype
assert metric_val.dtype == dtype
devices = [None, 'cpu']
devices = [None, "cpu"]
if torch.cuda.is_available():
devices += ['cuda:0']
devices += ["cuda:0"]
for device in devices:
for dtype in [None, torch.float32, torch.float64]:
@ -119,16 +123,12 @@ def test_metric(metric: Metric):
if torch.cuda.is_available():
metric.cuda(0)
assert metric.device == torch.device('cuda', index=0)
assert metric(input1, input2).device == torch.device('cuda', index=0)
assert metric.device == torch.device("cuda", index=0)
assert metric(input1, input2).device == torch.device("cuda", index=0)
metric.cpu()
assert metric.device == torch.device('cpu')
assert metric(input1, input2).device == torch.device('cpu')
metric.type(torch.int8)
assert metric.dtype == torch.int8
assert metric(input1, input2).dtype == torch.int8
assert metric.device == torch.device("cpu")
assert metric(input1, input2).device == torch.device("cpu")
metric.float()
assert metric.dtype == torch.float32
@ -156,7 +156,7 @@ def test_model_pickable(tmpdir, metric: Metric):
max_epochs=1,
limit_train_batches=10,
gpus=[0, 1],
distributed_backend='ddp_spawn',
distributed_backend="ddp_spawn",
)
model = EvalModelTemplate()
@ -167,17 +167,19 @@ def test_model_pickable(tmpdir, metric: Metric):
result = trainer.fit(model)
# correct result and ok accuracy
assert result == 1, 'ddp model failed to complete'
assert result == 1, "ddp model failed to complete"
@pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()])
def test_saving_pickable(tmpdir, metric: Metric):
""" Make sure that metrics are pickable by saving and loading them using torch """
x, y = torch.randn(10,), torch.randn(10,)
x, y = torch.randn(10,), torch.randn(
10,
)
results_before_save = metric(x, y)
# save metric
save_path = os.path.join(tmpdir, 'save_test.ckpt')
save_path = os.path.join(tmpdir, "save_test.ckpt")
torch.save(metric, save_path)
# load metric
@ -186,3 +188,125 @@ def test_saving_pickable(tmpdir, metric: Metric):
# Check metric value is the same
assert results_before_save == results_after_load
def check_call_order():
    """Check the order in which a ``Metric`` runs its hooks.

    A ``Metric`` subclass records the name of every hook it executes. The
    recorded history is compared with the expected sequence at four points:
    after a single call, after reading ``aggregated``, after one more call,
    and for a fresh metric that is called twice before ``aggregated`` is read
    (results ``1.0`` and ``3.0`` are expected to aggregate to ``2.0``).
    """

    class DummyMetric(Metric):
        """Metric computing ``tensor1 - tensor2`` that logs each hook it runs.

        Every overridden member appends its own name to ``call_history`` and
        then defers to the parent implementation.
        """

        def __init__(self):
            super().__init__("dummy")
            self.call_history = ["init"]

        # The static hooks take the metric explicitly as ``self``; zero-argument
        # ``super()`` is unavailable in a static method, hence the explicit
        # ``super(DummyMetric, self)`` form.
        @staticmethod
        def input_convert(self, data: Any):
            self.call_history.append("input_convert")
            return super(DummyMetric, self).input_convert(self, data)

        def forward(self, tensor1, tensor2):
            self.call_history.append("forward")
            return tensor1 - tensor2

        @staticmethod
        def output_convert(self, data: Any, output: Any):
            self.call_history.append("output_convert")
            return super(DummyMetric, self).output_convert(self, data, output)

        def ddp_sync(self, tensor: Any):
            self.call_history.append("ddp_sync")
            return super().ddp_sync(tensor)

        @staticmethod
        def ddp_reduce(self, data: Any, output: Any):
            self.call_history.append("ddp_reduce")
            return super(DummyMetric, self).ddp_reduce(self, data, output)

        def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
            self.call_history.append("aggregate")
            return super().aggregate(*tensors)

        def reset(self):
            self.call_history.append("reset")
            return super().reset()

        @property
        def aggregated(self) -> torch.Tensor:
            self.call_history.append("aggregated")
            return super().aggregated

        @staticmethod
        def compute(self, data: Any, output: Any):
            self.call_history.append("compute")
            return super(DummyMetric, self).compute(self, data, output)

    # Hook names expected from one metric call and from one read of
    # ``aggregated``; the full histories below are assembled from these.
    per_call = [
        "input_convert",
        "forward",
        "output_convert",
        "ddp_reduce",
        "ddp_sync",
        "aggregate",
    ]
    per_aggregated = ["aggregated", "aggregate", "reset"]

    # Scenario 1: call once, read the aggregate, call again.
    metric = DummyMetric()
    expected_history = ["init"]
    assert metric.call_history == expected_history

    result = metric(torch.tensor([2.0]), torch.tensor([1.0]))
    assert torch.allclose(result, torch.tensor(1.0))
    expected_history = expected_history + per_call
    assert metric.call_history == expected_history

    aggr = metric.aggregated
    expected_history = expected_history + per_aggregated
    assert metric.call_history == expected_history
    assert torch.allclose(aggr, result)

    _ = metric(torch.tensor(2.0), torch.tensor(1.0))
    expected_history = expected_history + per_call
    assert metric.call_history == expected_history

    # Scenario 2: a fresh metric called twice before the aggregate is read.
    metric = DummyMetric()
    _ = metric(torch.tensor([2.0]), torch.tensor([1.0]))
    _ = metric(torch.tensor([3.0]), torch.tensor([0.0]))

    aggregated = metric.aggregated
    assert torch.allclose(aggregated, torch.tensor(2.0))
    assert metric.call_history == ["init"] + per_call + per_call + per_aggregated

View File

@ -167,13 +167,12 @@ def test_sklearn_metric(metric_class, sklearn_func, inputs):
sklearn_result = sklearn_func(**numpy_inputs)
lightning_result = metric_class(**inputs)
assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
sklearn_result = apply_to_collection(
sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
lightning_result = apply_to_collection(
lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
lightning_result = np.array(apply_to_collection(
lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy))
assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
assert isinstance(lightning_result, type(sklearn_result))