revamp entire metrics (#3868)

* removed metric

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* added new metrics

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* pep8

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* pep8

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* docs

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* docs

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* win ddp tests skip

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* win ddp tests skip

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* win ddp tests skip

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* win ddp tests skip

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* reset in compute, cache compute

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* reduce_ops handling

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* sync -> sync_dist, type annotations

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* wip docs

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* mean squared error

* docstring

* added mean ___ error metrics

* added mean ___ error metrics

* separated files

* accuracy doctest

* gpu fix

* remove unnecessary mixin

* metric and accuracy docstring

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* metric docs

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* pep8, changelog

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>

* refactor dist utils, pep8

* refactor dist utils, pep8

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Ananya Harsh Jha 2020-10-06 17:03:24 -04:00 committed by GitHub
parent 4722cc0bf0
commit f76bc5254e
45 changed files with 1102 additions and 8082 deletions

View File

@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added new Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))
- Enable PyTorch 1.7 compatibility ([#3541](https://github.com/PyTorchLightning/pytorch-lightning/pull/3541))
- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))
@@ -63,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
- Remove old Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))
### Fixed

View File

@@ -3,124 +3,124 @@
import torch
from torch.nn import Module
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.metrics import TensorMetric, NumpyMetric
from pytorch_lightning.metrics import Metric
.. _metrics:
Metrics
=======
This is a general package for PyTorch Metrics. These can also be used with regular non-lightning PyTorch code.
Metrics are used to monitor model performance.
In this package, we provide three major pieces of functionality.
``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
common metric implementations.
1. A Metric class you can use to implement metrics with built-in distributed (ddp) support which are device agnostic.
2. A collection of ready to use popular metrics. There are two types of metrics: Class metrics and Functional metrics.
3. An interface to call `sklearn's metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits
``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
provided input.
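A minimal sketch of this cycle with the new ``Accuracy`` metric (shown here purely for illustration):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    accuracy = Accuracy()

    # forward() updates the internal state and returns the value for this batch
    batch_acc = accuracy(torch.tensor([0, 1, 1, 3]), torch.tensor([0, 1, 2, 3]))

    # compute() returns the value accumulated over all update()/forward() calls so far
    epoch_acc = accuracy.compute()

    # reset() clears the accumulated state, e.g. at the start of a new epoch
    accuracy.reset()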
Example::
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
logic present in ``.compute()`` is applied to state information from all processes.
from pytorch_lightning.metrics.functional import accuracy
The example below shows how to use a metric in your ``LightningModule``:
pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])
.. note::
# calculates accuracy across all GPUs and all Nodes used in training
accuracy(pred, target)
.. warning::
The metrics package is still in development and currently limited to a few metrics. If we're missing a metric
or you find a mistake, please feel free to create an issue/PR.
----------------
Implement a metric
------------------
You can implement metrics as either a PyTorch metric or a Numpy metric (It is recommended to use PyTorch metrics when possible,
since Numpy metrics slow down training).
Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.
Use :class:`NumpyMetric` to implement numpy metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.
.. warning::
Numpy metrics might slow down your training substantially,
since every metric computation requires a GPU sync to convert tensors to numpy.
----------------
TensorMetric
^^^^^^^^^^^^
Here's an example showing how to implement a TensorMetric
.. testcode::
class RMSE(TensorMetric):
def forward(self, x, y):
return torch.sqrt(torch.mean(torch.pow(x-y, 2.0)))
.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric
:noindex:
----------------
NumpyMetric
^^^^^^^^^^^
Here's an example showing how to implement a NumpyMetric
.. testcode::
class RMSE(NumpyMetric):
def forward(self, x, y):
return np.sqrt(np.mean(np.power(x-y, 2.0)))
.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
:noindex:
----------------
Class Metrics
-------------
Class metrics can be instantiated as part of a module definition (even with just
plain PyTorch).
.. testcode::
from pytorch_lightning.metrics import Accuracy
# Plain PyTorch
class MyModule(Module):
def __init__(self):
super().__init__()
self.metric = Accuracy()
def forward(self, x, y):
y_hat = ...
acc = self.metric(y_hat, y)
# PyTorch Lightning
class MyModule(LightningModule):
def __init__(self):
super().__init__()
self.metric = Accuracy()
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = ...
acc = self.metric(y_hat, y)
These metrics even work when using distributed training:
For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end of each epoch.
This has been shown in the example below. For v1.0 release, we will integrate metrics
with logging and ``.compute()`` will be called automatically by PyTorch Lightning.
.. code-block:: python
model = MyModule()
trainer = Trainer(gpus=8, num_nodes=2)
def __init__(self):
...
self.accuracy = pl.metrics.Accuracy()
def training_step(self, batch, batch_idx):
logits = self(x)
...
# log step metric
self.log('train_acc_step', self.accuracy(logits, y))
...
def training_epoch_end(self, outs):
# log epoch metric
self.log('train_acc_epoch', self.accuracy.compute())
# any metric automatically reduces across GPUs (even the ones you implement using Lightning)
trainer.fit(model)
This metrics API is independent of PyTorch Lightning. Metrics can directly be used in PyTorch as shown in the example:
.. code-block:: python
from pytorch_lightning import metrics
train_accuracy = metrics.Accuracy()
valid_accuracy = metrics.Accuracy(compute_on_step=False)
for epoch in range(epochs):
for x, y in train_data:
y_hat = model(x)
# training step accuracy
batch_acc = train_accuracy(y_hat, y)
for x, y in valid_data:
y_hat = model(x)
valid_accuracy(y_hat, y)
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
Implementing a Metric
---------------------
To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:
- ``__init__()``: Each state variable should be called using ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.
All you need to do is call ``add_state`` correctly to implement a custom metric with DDP.
``reset()`` is called on metric state variables added using ``add_state()``.
To see how metric states are synchronized across distributed processes, refer to ``add_state()`` docs
from the base ``Metric`` class.
Example implementation:
.. code-block:: python
from pytorch_lightning.metrics import Metric
class MyAccuracy(Metric):
def __init__(self, ddp_sync_on_step=False):
super().__init__(ddp_sync_on_step=ddp_sync_on_step)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
preds, target = self._input_format(preds, target)  # user-defined helper that normalizes preds/target to matching label tensors
assert preds.shape == target.shape
self.correct += torch.sum(preds == target)
self.total += target.numel()
def compute(self):
return self.correct.float() / self.total
Metric
^^^^^^
.. autoclass:: pytorch_lightning.metrics.Metric
:noindex:
Classification Metrics
----------------------
Accuracy
^^^^^^^^
@@ -128,510 +128,25 @@ Accuracy
.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
:noindex:
AveragePrecision
Regression Metrics
------------------
MeanSquaredError
^^^^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
:noindex:
AUROC
^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.AUROC
MeanAbsoluteError
^^^^^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
:noindex:
BLEUScore
^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.nlp.BLEUScore
:noindex:
ConfusionMatrix
^^^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:
DiceCoefficient
^^^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
:noindex:
EmbeddingSimilarity
MeanSquaredLogError
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.self_supervised.EmbeddingSimilarity
:noindex:
F1
^^
.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:
FBeta
^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:
PrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
:noindex:
Precision
^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.Precision
:noindex:
Recall
^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:
ROC
^^^
.. autoclass:: pytorch_lightning.metrics.classification.ROC
:noindex:
MAE
^^^
.. autoclass:: pytorch_lightning.metrics.regression.MAE
:noindex:
MSE
^^^
.. autoclass:: pytorch_lightning.metrics.regression.MSE
:noindex:
MulticlassROC
^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC
:noindex:
MulticlassPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecallCurve
:noindex:
IoU
^^^
.. autoclass:: pytorch_lightning.metrics.classification.IoU
:noindex:
RMSE
^^^^
.. autoclass:: pytorch_lightning.metrics.regression.RMSE
:noindex:
RMSLE
^^^^^
.. autoclass:: pytorch_lightning.metrics.regression.RMSLE
:noindex:
SSIM
^^^^
.. autoclass:: pytorch_lightning.metrics.regression.SSIM
:noindex:
----------------
Functional Metrics
------------------
Functional metrics can be called anywhere (even used with just plain PyTorch).
.. code-block:: python
from pytorch_lightning.metrics.functional import accuracy
pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])
# calculates accuracy across all GPUs and all Nodes used in training
accuracy(pred, target)
These metrics even work when using distributed training:
.. code-block:: python
class MyModule(...):
def forward(self, x, y):
return accuracy(x, y)
model = MyModule()
trainer = Trainer(gpus=8, num_nodes=2)
# any metric automatically reduces across GPUs (even the ones you implement using Lightning)
trainer.fit(model)
accuracy (F)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.accuracy
:noindex:
auc (F)
^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.auc
:noindex:
auroc (F)
^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.auroc
:noindex:
average_precision (F)
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.average_precision
:noindex:
bleu_score (F)
^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.bleu_score
:noindex:
confusion_matrix (F)
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
:noindex:
dice_score (F)
^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.dice_score
:noindex:
embedding_similarity (F)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity
:noindex:
f1_score (F)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.f1_score
:noindex:
fbeta_score (F)
^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score
:noindex:
multiclass_precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve
:noindex:
multiclass_roc (F)
^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc
:noindex:
precision (F)
^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.precision
:noindex:
precision_recall (F)
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
:noindex:
precision_recall_curve (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
:noindex:
recall (F)
^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.recall
:noindex:
roc (F)
^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.roc
:noindex:
stat_scores (F)
^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
:noindex:
iou (F)
^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.iou
:noindex:
mse (F)
^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.mse
:noindex:
rmse (F)
^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.rmse
:noindex:
mae (F)
^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.mae
:noindex:
rmsle (F)
^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.rmsle
:noindex:
psnr (F)
^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.psnr
:noindex:
ssim (F)
^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.ssim
:noindex:
stat_scores_multiple_classes (F)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
:noindex:
----------------
Metric pre-processing
---------------------
to_categorical (F)
^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.to_categorical
:noindex:
to_onehot (F)
^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
:noindex:
----------------
Sklearn interface
-----------------
Lightning supports `sklearn's metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
as a backend for calculating metrics. Sklearn's metrics are well tested and robust,
but they require conversion between PyTorch and numpy, which may slow down your computations.
To use the sklearn backend of metrics, simply import as follows:
.. code-block:: python
import pytorch_lightning.metrics.sklearns as plm
metric = plm.Accuracy(normalize=True)
val = metric(pred, target)
Each converted sklearn metric has the same interface as its
original counterpart (e.g. `Accuracy` takes the additional `normalize` keyword).
Like the native Lightning metrics, these converted sklearn metrics also come
with built-in distributed (ddp) support.
SklearnMetric (sk)
^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.SklearnMetric
:noindex:
Accuracy (sk)
^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Accuracy
:noindex:
AUC (sk)
^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.AUC
:noindex:
AveragePrecision (sk)
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
:noindex:
BalancedAccuracy (sk)
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.BalancedAccuracy
:noindex:
CohenKappaScore (sk)
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.CohenKappaScore
:noindex:
ConfusionMatrix (sk)
^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix
:noindex:
DCG (sk)
^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.DCG
:noindex:
F1 (sk)
^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.F1
:noindex:
FBeta (sk)
^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.FBeta
:noindex:
Hamming (sk)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Hamming
:noindex:
Hinge (sk)
^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Hinge
:noindex:
Jaccard (sk)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Jaccard
:noindex:
Precision (sk)
^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Precision
:noindex:
Recall (sk)
^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.Recall
:noindex:
PrecisionRecallCurve (sk)
^^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.PrecisionRecallCurve
:noindex:
ROC (sk)
^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.ROC
:noindex:
AUROC (sk)
^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.AUROC
:noindex:
ExplainedVariance (sk)
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.ExplainedVariance
:noindex:
MeanAbsoluteError (sk)
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanAbsoluteError
:noindex:
MeanSquaredError (sk)
^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredError
:noindex:
MeanSquaredLogError (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredLogError
:noindex:
MedianAbsoluteError (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MedianAbsoluteError
:noindex:
R2Score (sk)
^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.R2Score
:noindex:
MeanPoissonDeviance (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanPoissonDeviance
:noindex:
MeanGammaDeviance (sk)
^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanGammaDeviance
:noindex:
MeanTweedieDeviance (sk)
^^^^^^^^^^^^^^^^^^^^^^^^
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance
.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
:noindex:

View File

@@ -14,14 +14,13 @@
import numbers
from copy import copy
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable
import torch
from torch import Tensor
import os
from pytorch_lightning.metrics.converters import sync_ddp_if_available
from typing import Iterable
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
class Result(Dict):

View File

@@ -1,66 +1,4 @@
from pytorch_lightning.metrics.classification import (
Accuracy,
AveragePrecision,
ConfusionMatrix,
F1,
FBeta,
Recall,
ROC,
AUROC,
DiceCoefficient,
MulticlassPrecisionRecallCurve,
MulticlassROC,
Precision,
PrecisionRecallCurve,
IoU,
)
from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning.metrics.nlp import BLEUScore
from pytorch_lightning.metrics.self_supervised import EmbeddingSimilarity
from pytorch_lightning.metrics.regression import (
MAE,
MSE,
PSNR,
RMSE,
RMSLE,
SSIM
)
from pytorch_lightning.metrics.sklearns import (
AUC,
SklearnMetric,
)
from pytorch_lightning.metrics.metric import Metric
__classification_metrics = [
"AUC",
"AUROC",
"Accuracy",
"AveragePrecision",
"ConfusionMatrix",
"DiceCoefficient",
"F1",
"FBeta",
"MulticlassPrecisionRecallCurve",
"MulticlassROC",
"Precision",
"PrecisionRecallCurve",
"ROC",
"Recall",
"IoU",
]
__regression_metrics = [
"MAE",
"MSE",
"PSNR",
"RMSE",
"RMSLE",
"SSIM"
]
__sequence_metrics = ["BLEUScore"]
__selfsuper_metrics = ["EmbeddingSimilarity"]
__all__ = __regression_metrics \
+ __classification_metrics \
+ __selfsuper_metrics \
+ __sequence_metrics \
+ ["SklearnMetric"]
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError

View File

@@ -1,866 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Sequence, Tuple
import torch
from pytorch_lightning.metrics.functional.classification import (
accuracy,
auroc,
average_precision,
confusion_matrix,
_confmat_normalize,
dice_score,
f1_score,
fbeta_score,
iou,
multiclass_precision_recall_curve,
multiclass_roc,
precision_recall_curve,
roc,
precision_recall
)
from pytorch_lightning.metrics.functional.reduction import class_reduce
from pytorch_lightning.metrics.metric import TensorMetric
class Accuracy(TensorMetric):
"""
Computes the accuracy classification score
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Accuracy()
>>> metric(pred, target)
tensor(0.7500)
"""
def __init__(
self,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
class_reduction: method to reduce metric score over labels
- ``'micro'``: calculate metrics globally (default)
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(name="accuracy", reduce_group=reduce_group)
self.num_classes = num_classes
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
self.class_reduction = class_reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the classification score.
"""
return accuracy(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
tps, sups = output['tps'], output['sups']
return class_reduce(tps, sups, sups, class_reduction=self.class_reduction)
class ConfusionMatrix(TensorMetric):
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted in group j.
Example:
>>> pred = torch.tensor([0, 1, 2, 2])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = ConfusionMatrix()
>>> metric(pred, target)
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 2.]])
"""
def __init__(
self,
num_classes: Optional[int] = None,
normalize: bool = False,
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
normalize: whether to compute a normalized confusion matrix
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="confusion_matrix",
reduce_group=reduce_group,
)
self.normalize = normalize
self.num_classes = num_classes
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the confusion matrix.
"""
return confusion_matrix(pred=pred, target=target,
normalize=False, # we normalize after ddp sync
num_classes=self.num_classes)
@staticmethod
def compute(self, data: Any, output: Any):
""" Confusion matrix normalization needs to happen after ddp sync """
confmat = output
if self.normalize:
confmat = _confmat_normalize(confmat)
return confmat
class PrecisionRecallCurve(TensorMetric):
"""
Computes the precision recall curve
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = PrecisionRecallCurve()
>>> prec, recall, thr = metric(pred, target)
>>> prec
tensor([0.3333, 0.0000, 0.0000, 1.0000])
>>> recall
tensor([1., 0., 0., 0.])
>>> thr
tensor([1., 2., 3.])
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="precision_recall_curve",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
- precision values
- recall values
- threshold values
"""
return precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight, pos_label=self.pos_label)
class Precision(TensorMetric):
"""
Computes the precision score
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Precision(num_classes=4, class_reduction='macro')
>>> metric(pred, target)
tensor(0.7500)
"""
def __init__(
self,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
class_reduction: method to reduce metric score over labels
- ``'micro'``: calculate metrics globally (default)
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="precision",
reduce_group=reduce_group,
)
self.num_classes = num_classes
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
self.class_reduction = class_reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the classification score.
"""
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
tps, fps, sups = output['tps'], output['fps'], output['sups']
return class_reduce(tps, tps + fps, sups, class_reduction=self.class_reduction)
class Recall(TensorMetric):
"""
Computes the recall score
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Recall()
>>> metric(pred, target)
tensor(0.7500)
"""
def __init__(
self,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
class_reduction: method to reduce metric score over labels
- ``'micro'``: calculate metrics globally (default)
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="recall",
reduce_group=reduce_group,
)
self.num_classes = num_classes
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
self.class_reduction = class_reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the classification score.
"""
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
tps, fns, sups = output['tps'], output['fns'], output['sups']
return class_reduce(tps, tps + fns, sups, class_reduction=self.class_reduction)
class AveragePrecision(TensorMetric):
"""
Computes the average precision score
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = AveragePrecision()
>>> metric(pred, target)
tensor(0.3333)
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="AP",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
torch.Tensor: classification score
"""
return average_precision(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
class AUROC(TensorMetric):
"""
Computes the area under curve (AUC) of the receiver operator characteristic (ROC)
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 1, 0])
>>> metric = AUROC()
>>> metric(pred, target)
tensor(0.5000)
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="auroc",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
torch.Tensor: classification score
"""
return auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
class FBeta(TensorMetric):
"""
Computes the FBeta Score, which is the weighted harmonic mean of precision and recall.
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = FBeta(0.25, class_reduction='macro')
>>> metric(pred, target)
tensor(0.7361)
"""
def __init__(
self,
beta: float,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
reduce_group: Any = None,
):
"""
Args:
beta: determines the weight of recall in the combined score.
num_classes: number of classes
class_reduction: method to reduce metric score over labels
- ``'micro'``: calculate metrics globally (default)
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="fbeta",
reduce_group=reduce_group,
)
self.beta = beta
self.num_classes = num_classes
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
self.class_reduction = class_reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
Return:
torch.Tensor: classification score
"""
return precision_recall(pred=pred, target=target,
num_classes=self.num_classes,
class_reduction='none',
return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
""" tps, fps, fns, sups needs to be synced before we do any calculations """
tps, fps, fns, sups = output['tps'], output['fps'], output['fns'], output['sups']
intermediate_reduction = 'none' if self.class_reduction != "micro" else 'micro'
precision = class_reduce(tps, tps + fps, sups, class_reduction=intermediate_reduction)
recall = class_reduce(tps, tps + fns, sups, class_reduction=intermediate_reduction)
num = (1 + self.beta ** 2) * precision * recall
denom = ((self.beta ** 2) * precision + recall)
if intermediate_reduction == 'micro':
return torch.sum(num) / torch.sum(denom)
return class_reduce(num, denom, sups, class_reduction=self.class_reduction)
class F1(FBeta):
"""
Computes the F1 score, which is the harmonic mean of the precision and recall.
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = F1(class_reduction='macro')
>>> metric(pred, target)
tensor(0.6667)
"""
def __init__(
self,
num_classes: Optional[int] = None,
class_reduction: str = 'micro',
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
class_reduction: method to reduce metric score over labels
- ``'micro'``: calculate metrics globally (default)
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(beta=1.0,
num_classes=num_classes,
class_reduction=class_reduction,
reduce_group=reduce_group)
self.name = "f1"
class ROC(TensorMetric):
"""
Computes the Receiver Operator Characteristic (ROC)
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = ROC()
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
(tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([4., 3., 2., 1., 0.]))
"""
def __init__(
self,
pos_label: int = 1,
reduce_group: Any = None,
):
"""
Args:
pos_label: positive label indicator
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="roc",
reduce_group=reduce_group,
)
self.pos_label = pos_label
def forward(
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
Args:
pred: predicted labels
target: groundtruth labels
sample_weight: the weights per sample
Return:
- false positive rate
- true positive rate
- thresholds
"""
return roc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
class MulticlassROC(TensorMetric):
"""
Computes the multiclass ROC
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = MulticlassROC()
>>> classes_roc = metric(pred, target)
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="multiclass_roc",
reduce_group=reduce_group,
)
self.num_classes = num_classes
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Actual metric computation
Args:
pred: predicted probability for each label
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
Return:
tuple: A tuple consisting of one tuple per class, holding false positive rate, true positive rate and thresholds
"""
return multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes)
def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Aggregates results by stacking them instead of concatenating before averaging.
Returns:
the aggregated results
"""
return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])
class MulticlassPrecisionRecallCurve(TensorMetric):
"""Computes the multiclass PR Curve
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = MulticlassPrecisionRecallCurve()
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])),
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])))
"""
def __init__(
self,
num_classes: Optional[int] = None,
reduce_group: Any = None,
):
"""
Args:
num_classes: number of classes
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="multiclass_precision_recall_curve",
reduce_group=reduce_group,
)
self.num_classes = num_classes
def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Actual metric computation
Args:
pred: predicted probability for each label
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
Return:
tuple: A tuple consisting of one tuple per class, holding precision, recall and thresholds
"""
return multiclass_precision_recall_curve(
pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes
)
def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Aggregates results by stacking them instead of concatenating before averaging.
Returns:
the aggregated results
"""
return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])
class DiceCoefficient(TensorMetric):
"""
Computes the dice coefficient
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = DiceCoefficient()
>>> metric(pred, target)
tensor(0.3333)
"""
def __init__(
self,
include_background: bool = False,
nan_score: float = 0.0,
no_fg_score: float = 0.0,
reduction: str = "elementwise_mean",
reduce_group: Any = None,
):
"""
Args:
include_background: whether to also compute dice for the background
nan_score: score to return, if a NaN occurs during computation (denom zero)
no_fg_score: score to return, if no foreground pixel was found in target
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(
name="dice",
reduce_group=reduce_group,
)
self.include_background = include_background
self.nan_score = nan_score
self.no_fg_score = no_fg_score
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted probability for each label
target: groundtruth labels
Return:
torch.Tensor: the calculated dice coefficient
"""
return dice_score(
pred=pred,
target=target,
bg=self.include_background,
nan_score=self.nan_score,
no_fg_score=self.no_fg_score,
reduction=self.reduction,
)
class IoU(TensorMetric):
"""
Computes the intersection over union.
Example:
>>> pred = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
... [0, 0, 1, 1, 1, 0, 0, 0],
... [0, 0, 0, 0, 0, 0, 0, 0]])
>>> target = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
... [0, 0, 0, 1, 1, 1, 0, 0],
... [0, 0, 0, 0, 0, 0, 0, 0]])
>>> metric = IoU()
>>> metric(pred, target)
tensor(0.7045)
"""
def __init__(
self,
ignore_index: Optional[int] = None,
absent_score: float = 0.0,
num_classes: Optional[int] = None,
reduction: str = "elementwise_mean",
):
"""
Args:
ignore_index: optional int specifying a target class to ignore. If given, this class index does not
contribute to the returned score, regardless of reduction method. Has no effect if given an int that is
not in the range [0, num_classes-1], where num_classes is either given or derived from pred and target.
By default, no index is ignored, and all classes are used.
absent_score: score to use for an individual class, if no instances of the class index were present in
`y_pred` AND no instances of the class index were present in `y_true`. For example, if we have 3
classes, [0, 0] for `y_pred`, and [0, 2] for `y_true`, then class 1 would be assigned the
`absent_score`. Default is 0.0.
num_classes: Optionally specify the number of classes
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""
super().__init__(name="iou")
self.ignore_index = ignore_index
self.absent_score = absent_score
self.num_classes = num_classes
self.reduction = reduction
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, sample_weight: Optional[torch.Tensor] = None):
"""
Actual metric calculation.
"""
return iou(
pred=y_pred,
target=y_true,
ignore_index=self.ignore_index,
absent_score=self.absent_score,
num_classes=self.num_classes,
reduction=self.reduction,
)

View File

@@ -0,0 +1 @@
from pytorch_lightning.metrics.classification.accuracy import Accuracy

View File

@@ -0,0 +1,103 @@
import math
import functools
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union
from collections.abc import Mapping, Sequence
from collections import namedtuple
import torch
from torch import nn
from pytorch_lightning.metrics.metric import Metric
class Accuracy(Metric):
"""
Computes accuracy. Works with binary, multiclass, and multilabel data.
Accepts logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
Forward accepts
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
threshold:
Threshold value for binary or multi-label logits. default: 0.5
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
ddp_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example:
>>> from pytorch_lightning.metrics import Accuracy
>>> target = torch.tensor([0, 1, 2, 3])
>>> preds = torch.tensor([0, 2, 1, 3])
>>> accuracy = Accuracy()
>>> accuracy(preds, target)
tensor(0.5000)
"""
def __init__(
self,
threshold: float = 0.5,
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
self.threshold = threshold
def _input_format(self, preds: torch.Tensor, target: torch.Tensor):
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError(
"preds and target must have same number of dimensions, or one additional dimension for preds"
)
if len(preds.shape) == len(target.shape) + 1:
# multi-class probabilities
preds = torch.argmax(preds, dim=1)
if len(preds.shape) == len(target.shape) and preds.dtype == torch.float:
# binary or multilabel probabilities
preds = (preds >= self.threshold).long()
return preds, target
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
preds, target = self._input_format(preds, target)
assert preds.shape == target.shape
self.correct += torch.sum(preds == target)
self.total += target.numel()
def compute(self):
"""
Computes accuracy over state.
"""
return self.correct.float() / self.total
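As an illustration (not part of the diff) of the two input paths the docstring above describes, a minimal sketch:

import torch
from pytorch_lightning.metrics import Accuracy

accuracy = Accuracy(threshold=0.5)

# (N, C) float predictions take the multi-class path: argmax over dim=1
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
target = torch.tensor([0, 1, 1])
accuracy(preds, target)  # tensor(0.6667)

# same-shape float predictions take the binary/multi-label path: thresholding at 0.5
accuracy.reset()
accuracy(torch.tensor([0.3, 0.9]), torch.tensor([0, 1]))  # tensor(1.)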

View File

@@ -1,410 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file provides functions and decorators for automated input and output
conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as utilities to
sync tensors between different processes in a DDP scenario, when needed.
"""
from functools import reduce
import numbers
from typing import Any, Callable, Optional, Union
import numpy as np
import torch
from torch.utils.data._utils.collate import np_str_obj_array_pattern
from pytorch_lightning.utilities.apply_func import apply_to_collection
if torch.distributed.is_available():
from torch.distributed import ReduceOp
else:
class ReduceOp:
SUM = None
def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
"""
Decorator function to apply a function to all inputs of a function.
Args:
func_to_apply: the function to apply to the inputs
*dec_args: positional arguments for the function to be applied
**dec_kwargs: keyword arguments for the function to be applied
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
# actual function applying the given function to inputs
def new_func(*args, **kwargs):
args = func_to_apply(args, *dec_args, **dec_kwargs)
kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
return func_to_decorate(*args, **kwargs)
return new_func
return decorator_fn
def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
"""
Decorator function to apply a function to all outputs of a function.
Args:
func_to_apply: the function to apply to the outputs
*dec_args: positional arguments for the function to be applied
**dec_kwargs: keyword arguments for the function to be applied
Return:
the decorated function
"""
def decorator_fn(function_to_decorate):
# actual function applying the given function to outputs
def new_func(*args, **kwargs):
result = function_to_decorate(*args, **kwargs)
return func_to_apply(result, *dec_args, **dec_kwargs)
return new_func
return decorator_fn
def convert_to_tensor(data: Any, dtype=None, device=None) -> Any:
"""
Maps all kind of collections and numbers to tensors.
Args:
data: the data to convert to tensor
dtype: data type to convert to
device: device to cast to
Return:
the converted data
"""
if isinstance(data, numbers.Number):
return torch.tensor([data], dtype=dtype, device=device)
# is not array of object
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
return torch.from_numpy(data).to(device=device, dtype=dtype)
elif isinstance(data, torch.Tensor):
return data.to(device=device, dtype=dtype)
raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")
def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
"""Convert all tensors and numpy arrays to numpy arrays.
Args:
data: the tensor or array to convert to numpy
Return:
the resulting numpy array
"""
if isinstance(data, torch.Tensor):
return data.cpu().detach().numpy()
elif isinstance(data, numbers.Number):
return np.array([data])
elif isinstance(data, np.ndarray):
return data
raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all inputs of a function to numpy
Args:
func_to_decorate: the function whose inputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(
func_to_decorate
)
def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all outputs of a function to tensors
Args:
func_to_decorate: the function whose outputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_outputs(convert_to_tensor)(func_to_decorate)
def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on numpy.
All inputs of the decorated function will be converted to numpy and all
outputs will be converted to tensors.
Args:
func_to_decorate: the function whose inputs and outputs shall be converted
Return:
the decorated function
"""
# applies collection conversion from tensor to numpy to all inputs
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
func_convert_inputs = _numpy_metric_input_conversion(func_to_decorate)
# converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
func_convert_in_out = _tensor_metric_output_conversion(func_convert_inputs)
return func_convert_in_out
def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all inputs of a function to tensors
Args:
func_to_decorate: the function whose inputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
func_to_decorate
)
def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator converting all numpy arrays and numbers occurring in the outputs of a function to tensors
Args:
func_to_decorate: the function whose outputs shall be converted
Return:
Callable: the decorated function
"""
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
func_to_decorate
)
def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on tensors.
All inputs and outputs of the decorated function will be converted to tensors
Args:
func_to_decorate: the function whose inputs and outputs shall be converted
Return:
the decorated function
"""
# converts all inputs to tensor if possible
# we need to include tensors here, since otherwise they will also be treated as sequences
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
# convert all outputs to tensor if possible
return _tensor_metric_output_conversion(func_convert_inputs)
def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable:
"""
Decorator handling the argument conversion for metrics working on tensors.
All inputs of the decorated function and all numpy arrays and numbers in
its outputs will be converted to tensors
Args:
func_to_decorate: the function whose inputs and outputs shall be converted
Return:
the decorated function
"""
# converts all inputs to tensor if possible
# we need to include tensors here, since otherwise they will also be treated as sequences
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
# convert all outputs to tensor if possible
return _tensor_collection_metric_output_conversion(func_convert_inputs)
def sync_ddp_if_available(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process
Args:
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum.
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
Return:
reduced value
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
divide_by_world_size = False
if group is None:
group = torch.distributed.group.WORLD
if reduce_op is None:
reduce_op = torch.distributed.ReduceOp.SUM
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
reduce_op = torch.distributed.ReduceOp.SUM
divide_by_world_size = True
# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)
return result
def at_least_1d(tensor: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""Makes sure the tensor is at least of 1d shape
Args:
tensor: the tensor or array to check the shape for
Returns:
the optionally reshaped tensor
"""
if tensor.shape == ():
tensor = tensor.reshape(1, )
return tensor
def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes
Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)
Return:
gathered_result: list with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if group is None:
group = torch.distributed.group.WORLD
world_size = torch.distributed.get_world_size(group)
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
# sync and broadcast all
torch.distributed.barrier(group=group)
torch.distributed.all_gather(gathered_result, result, group)
result = gathered_result
return result
def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator syncs a function's outputs across different processes for DDP.
Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return _apply_to_outputs(
apply_to_collection, torch.Tensor, sync_ddp_if_available, group=group, reduce_op=reduce_op
)(func_to_decorate)
return decorator_fn
def numpy_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on numpy arrays.
It handles the argument conversion and DDP reduction for metrics working on numpy.
All inputs of the decorated function will be converted to numpy and all
outputs will be converted to tensors.
In DDP Training all output tensors will be reduced according to the given rules.
Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return sync_ddp(group=group, reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
return decorator_fn
def tensor_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors.
It handles the argument conversion and DDP reduction for metrics working on tensors.
All inputs and outputs of the decorated function will be converted to tensors.
In DDP Training all output tensors will be reduced according to the given rules.
Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
return decorator_fn
def tensor_collection_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
"""
This decorator shall be used on all function metrics working on tensors and returning collections
that cannot be converted to tensors.
It handles the argument conversion and DDP reduction for metrics working on tensors.
All inputs and outputs of the decorated function will be converted to tensors.
In DDP Training all output tensors will be reduced according to the given rules.
Args:
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum
Return:
the decorated function
"""
def decorator_fn(func_to_decorate):
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_collection_metric_conversion(func_to_decorate))
return decorator_fn
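For reference, a minimal sketch of how these (now removed) decorators were typically applied; ``my_np_accuracy`` is a made-up function used purely for illustration, and the snippet assumes the conversion helpers defined earlier in this module:

import numpy as np
import torch

@numpy_metric(reduce_op='mean')
def my_np_accuracy(pred: np.ndarray, target: np.ndarray) -> float:
    # the decorator hands the inputs over as numpy arrays
    return (pred == target).mean()

# tensor inputs are converted to numpy on the way in and the result comes back as a
# tensor; under an initialized DDP group the value is averaged across processes
score = my_np_accuracy(torch.tensor([0, 1, 1, 0]), torch.tensor([0, 1, 0, 0]))  # roughly tensor(0.75)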

@@ -1,34 +0,0 @@
from pytorch_lightning.metrics.functional.classification import (
accuracy,
auc,
auroc,
average_precision,
confusion_matrix,
dice_score,
f1_score,
fbeta_score,
multiclass_precision_recall_curve,
multiclass_roc,
precision,
precision_recall,
precision_recall_curve,
recall,
roc,
stat_scores,
stat_scores_multiple_classes,
to_categorical,
to_onehot,
iou,
)
from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.functional.regression import (
mae,
mse,
psnr,
rmse,
rmsle,
ssim
)
from pytorch_lightning.metrics.functional.self_supervised import (
embedding_similarity
)

File diff suppressed because it is too large
@@ -1,103 +0,0 @@
# referenced from
# Library Name: torchtext
# Authors: torchtext authors and @sluks
# Date: 2020-07-18
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
from collections import Counter
from typing import List, Sequence
import torch
def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
"""
Counting how many times each word appears in a given text with ngram
Args:
ngram_input_list: A list of translated text or reference texts
n_gram: gram value ranged 1 to 4
Return:
ngram_counter: a collections.Counter object of ngram
"""
ngram_counter = Counter()
for i in range(1, n_gram + 1):
for j in range(len(ngram_input_list) - i + 1):
ngram_key = tuple(ngram_input_list[j:(i + j)])
ngram_counter[ngram_key] += 1
return ngram_counter
def bleu_score(
translate_corpus: Sequence[str],
reference_corpus: Sequence[str],
n_gram: int = 4,
smooth: bool = False
) -> torch.Tensor:
"""
Calculate BLEU score of machine translated text with one or more references
Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
n_gram: Gram value ranged from 1 to 4 (Default 4)
smooth: Whether or not to apply smoothing Lin et al. 2004
Return:
Tensor with BLEU Score
Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> bleu_score(translate_corpus, reference_corpus)
tensor(0.7598)
"""
assert len(translate_corpus) == len(reference_corpus)
numerator = torch.zeros(n_gram)
denominator = torch.zeros(n_gram)
precision_scores = torch.zeros(n_gram)
c = 0.0
r = 0.0
for (translation, references) in zip(translate_corpus, reference_corpus):
c += len(translation)
ref_len_list = [len(ref) for ref in references]
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
translation_counter = _count_ngram(translation, n_gram)
reference_counter = Counter()
for ref in references:
reference_counter |= _count_ngram(ref, n_gram)
ngram_counter_clip = translation_counter & reference_counter
for counter_clip in ngram_counter_clip:
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
for counter in translation_counter:
denominator[len(counter) - 1] += translation_counter[counter]
trans_len = torch.tensor(c)
ref_len = torch.tensor(r)
if min(numerator) == 0.0:
return torch.tensor(0.0)
if smooth:
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
else:
precision_scores = numerator / denominator
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
geometric_mean = torch.exp(torch.sum(log_precision_scores))
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
bleu = brevity_penalty * geometric_mean
return bleu

@@ -1,65 +0,0 @@
import torch
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
"""
Reduces a given tensor by a given reduction method
Args:
to_reduce : the tensor, which shall be reduced
reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
Return:
reduced Tensor
Raise:
ValueError if an invalid reduction parameter was given
"""
if reduction == 'elementwise_mean':
return torch.mean(to_reduce)
if reduction == 'none':
return to_reduce
if reduction == 'sum':
return torch.sum(to_reduce)
raise ValueError('Reduction parameter unknown.')
def class_reduce(num: torch.Tensor,
denom: torch.Tensor,
weights: torch.Tensor,
class_reduction: str = 'none') -> torch.Tensor:
"""
Function used to reduce classification metrics of the form `num / denom * weights`.
For example for calculating standard accuracy the num would be number of
true positives per class, denom would be the support per class, and weights
would be a tensor of 1s
Args:
num: numerator tensor
denom: denominator tensor
weights: weights for each class
class_reduction: reduction method for multiclass problems
- ``'micro'``: calculate metrics globally
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'``: returns calculated metric per class
"""
valid_reduction = ('micro', 'macro', 'weighted', 'none')
if class_reduction == 'micro':
return torch.sum(num) / torch.sum(denom)
# For the rest we need to take care of instances where the denom can be 0
# for some classes which will produce nans for that class
fraction = num / denom
fraction[fraction != fraction] = 0
if class_reduction == 'macro':
return torch.mean(fraction)
elif class_reduction == 'weighted':
return torch.sum(fraction * (weights / torch.sum(weights)))
elif class_reduction == 'none':
return fraction
raise ValueError(f'Reduction parameter {class_reduction} unknown.'
f' Choose between one of these: {valid_reduction}')
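A small usage sketch for ``class_reduce``, assuming per-class true positives and supports for a hypothetical 3-class problem:

import torch

num = torch.tensor([10., 5., 0.])     # per-class true positives
denom = torch.tensor([20., 10., 5.])  # per-class support
weights = torch.ones(3)

class_reduce(num, denom, weights, class_reduction='micro')  # 15 / 35 = tensor(0.4286)
class_reduce(num, denom, weights, class_reduction='macro')  # mean of [0.50, 0.50, 0.00]
class_reduce(num, denom, weights, class_reduction='none')   # tensor([0.5000, 0.5000, 0.0000])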

@@ -1,325 +0,0 @@
from typing import Sequence
import torch
from torch.nn import functional as F
from pytorch_lightning.metrics.functional.reduction import reduce
def mse(
pred: torch.Tensor,
target: torch.Tensor,
reduction: str = 'elementwise_mean',
return_state: bool = False
) -> torch.Tensor:
"""
Computes mean squared error
Args:
pred: estimated labels
target: ground truth labels
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
return_state: returns an internal state that can be ddp reduced
before doing the final calculation
Return:
Tensor with MSE
Example:
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mse(x, y)
tensor(0.2500)
"""
mse = F.mse_loss(pred, target, reduction='none')
if return_state:
return {'squared_error': mse.sum(), 'n_observations': torch.tensor(mse.numel())}
mse = reduce(mse, reduction=reduction)
return mse
def rmse(
pred: torch.Tensor,
target: torch.Tensor,
reduction: str = 'elementwise_mean',
return_state: bool = False
) -> torch.Tensor:
"""
Computes root mean squared error
Args:
pred: estimated labels
target: ground truth labels
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
return_state: returns an internal state that can be ddp reduced
before doing the final calculation
Return:
Tensor with RMSE
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> rmse(x, y)
tensor(0.5000)
"""
mean_squared_error = mse(pred, target, reduction=reduction)
if return_state:
return {'squared_error': mean_squared_error.sum(),
'n_observations': torch.tensor(mean_squared_error.numel())}
return torch.sqrt(mean_squared_error)
def mae(
pred: torch.Tensor,
target: torch.Tensor,
reduction: str = 'elementwise_mean',
return_state: bool = False
) -> torch.Tensor:
"""
Computes mean absolute error
Args:
pred: estimated labels
target: ground truth labels
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
return_state: returns an internal state that can be ddp reduced
before doing the final calculation
Return:
Tensor with MAE
Example:
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> mae(x, y)
tensor(0.2500)
"""
mae = F.l1_loss(pred, target, reduction='none')
if return_state:
return {'absolute_error': mae.sum(), 'n_observations': torch.tensor(mae.numel())}
mae = reduce(mae, reduction=reduction)
return mae
def rmsle(
pred: torch.Tensor,
target: torch.Tensor,
reduction: str = 'elementwise_mean'
) -> torch.Tensor:
"""
Computes root mean squared log error
Args:
pred: estimated labels
target: ground truth labels
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
Return:
Tensor with RMSLE
Example:
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> rmsle(x, y)
tensor(0.1438)
"""
rmsle = rmse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction)
return rmsle
def psnr(
pred: torch.Tensor,
target: torch.Tensor,
data_range: float = None,
base: float = 10.0,
reduction: str = 'elementwise_mean',
return_state: bool = False
) -> torch.Tensor:
"""
Computes the peak signal-to-noise ratio
Args:
pred: estimated signal
target: ground truth signal
data_range: the range of the data. If None, it is determined from the data (max - min)
base: a base of a logarithm to use (default: 10)
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
return_state: returns an internal state that can be ddp reduced
before doing the final calculation
Return:
Tensor with PSNR score
Example:
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> psnr(pred, target)
tensor(2.5527)
"""
if data_range is None:
data_range = target.max() - target.min()
else:
data_range = torch.tensor(float(data_range))
if return_state:
return {'data_range': data_range,
'sum_squared_error': F.mse_loss(pred, target, reduction='none').sum(),
'n_obs': torch.tensor(target.numel())}
mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction)
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
return psnr
def _gaussian_kernel(channel, kernel_size, sigma, device):
def _gaussian(kernel_size, sigma, device):
gauss = torch.arange(
start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2,
step=1,
dtype=torch.float32,
device=device
)
gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
def ssim(
pred: torch.Tensor,
target: torch.Tensor,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: float = None,
k1: float = 0.01,
k2: float = 0.03
) -> torch.Tensor:
"""
Computes Structural Similarity Index Measure
Args:
pred: estimated image
target: ground truth image
kernel_size: size of the gaussian kernel (default: (11, 11))
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM. Default: 0.01
k2: Parameter of SSIM. Default: 0.03
Return:
Tensor with SSIM score
Example:
>>> pred = torch.rand([16, 1, 16, 16])
>>> target = pred * 0.75
>>> ssim(pred, target)
tensor(0.9219)
"""
if pred.dtype != target.dtype:
raise TypeError(
"Expected `pred` and `target` to have the same data type."
f" Got pred: {pred.dtype} and target: {target.dtype}."
)
if pred.shape != target.shape:
raise ValueError(
"Expected `pred` and `target` to have the same shape."
f" Got pred: {pred.shape} and target: {target.shape}."
)
if len(pred.shape) != 4 or len(target.shape) != 4:
raise ValueError(
"Expected `pred` and `target` to have BxCxHxW shape."
f" Got pred: {pred.shape} and target: {target.shape}."
)
if len(kernel_size) != 2 or len(sigma) != 2:
raise ValueError(
"Expected `kernel_size` and `sigma` to have the length of two."
f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
)
if any(x % 2 == 0 or x <= 0 for x in kernel_size):
raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")
if any(y <= 0 for y in sigma):
raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")
if data_range is None:
data_range = max(pred.max() - pred.min(), target.max() - target.min())
C1 = pow(k1 * data_range, 2)
C2 = pow(k2 * data_range, 2)
device = pred.device
channel = pred.size(1)
kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
# Concatenate
# pred for mu_pred
# target for mu_target
# pred * pred for sigma_pred
# target * target for sigma_target
# pred * target for sigma_pred_target
input_list = torch.cat([pred, target, pred * pred, target * target, pred * target]) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)
output_list = [outputs[x * pred.size(0): (x + 1) * pred.size(0)] for x in range(len(outputs))]
mu_pred_sq = output_list[0].pow(2)
mu_target_sq = output_list[1].pow(2)
mu_pred_target = output_list[0] * output_list[1]
sigma_pred_sq = output_list[2] - mu_pred_sq
sigma_target_sq = output_list[3] - mu_target_sq
sigma_pred_target = output_list[4] - mu_pred_target
UPPER = 2 * sigma_pred_target + C2
LOWER = sigma_pred_sq + sigma_target_sq + C2
ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)
return reduce(ssim_idx, reduction)

@@ -1,46 +0,0 @@
import torch
def embedding_similarity(
batch: torch.Tensor,
similarity: str = 'cosine',
reduction: str = 'none',
zero_diagonal: bool = True
) -> torch.Tensor:
"""
Computes representation similarity
Example:
>>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
>>> embedding_similarity(embeddings)
tensor([[0.0000, 1.0000, 0.9759],
[1.0000, 0.0000, 0.9759],
[0.9759, 0.9759, 0.0000]])
Args:
batch: (batch, dim)
similarity: 'dot' or 'cosine'
reduction: 'none', 'sum', 'mean' (all along dim -1)
zero_diagonal: if True, the diagonals are set to zero
Return:
A square matrix (batch, batch) with the similarity scores between all elements
If sum or mean are used, then returns (b, 1) with the reduced value for each row
"""
if similarity == 'cosine':
norm = torch.norm(batch, p=2, dim=1)
batch = batch / norm.unsqueeze(1)
sqr_mtx = batch.mm(batch.transpose(1, 0))
if zero_diagonal:
sqr_mtx = sqr_mtx.fill_diagonal_(0)
if reduction == 'mean':
sqr_mtx = sqr_mtx.mean(dim=-1)
if reduction == 'sum':
sqr_mtx = sqr_mtx.sum(dim=-1)
return sqr_mtx

@@ -1,262 +1,220 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from abc import ABC, abstractmethod
from typing import Any, Mapping, Optional, Sequence
import numbers
from typing import Any, Callable, Optional, Union
from collections.abc import Mapping, Sequence
from collections import namedtuple
from copy import deepcopy
import os
import torch
from torch import nn
import numpy as np
from pytorch_lightning.metrics.converters import (
at_least_1d,
gather_all_tensors_if_available,
convert_to_tensor,
convert_to_numpy,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available
from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
class Metric(nn.Module, ABC):
"""
Abstract base class for metric implementation.
Base class for all metrics present in the Metrics API.
Should be used to implement metrics that
Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to
handle distributed synchronization and per step metric computation.
1. Return multiple Outputs
2. Handle their own DDP sync
Override ``update()`` and ``compute()`` functions to implement your own metric. Use
``add_state()`` to register metric state variables which keep track of state on each
call of ``update()`` and are synchronized across processes when ``compute()`` is called.
Metric hooks that can be implemented are
Note:
Metric state variables can either be ``torch.Tensors`` or an empty list which can be used
to store ``torch.Tensors``.
* input_convert: pre-forward hook that takes care of input conversion
* output_convert: post-forward hook that takes care of output conversion
* ddp_reduce: implementation of ddp sync + aggregation, default is ddp_sync + aggregate
* compute: post-ddp sync for additional metric computations
``ddp_reduce`` by default calls the following methods, which can also be overwritten if necessary.
* ddp_sync: implements how values should be synced across ddp-processes. Defaults to gather all.
* aggregate: implement how values should be aggregated (defaults to mean).
Call order
input_convert -> forward -> output_convert -> ddp_reduce (per default being ddp_sync -> aggregate) -> compute
Note:
Different metrics only override ``update()`` and not ``forward()``. A call to ``update()``
is valid, but it won't return the metric value at the current step. A call to ``forward()``
automatically calls ``update()`` and also returns the metric value at the current step.
Args:
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
ddp_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
"""
def __init__(self, name: str, reduce_group: Optional[Any] = None):
"""
Args:
name: the metric's name
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
"""
def __init__(
self,
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__()
self.name = name
self._dtype = torch.get_default_dtype()
self._device = torch.device("cpu")
self.reduce_group = reduce_group
self.ddp_sync_on_step = ddp_sync_on_step
self.compute_on_step = compute_on_step
self.process_group = process_group
self._to_sync = True
# Buffer for holding aggregated state after each batch
self._step_vals = []
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)
self._computed = None
# Register hooks
self.register_forward_pre_hook(self.input_convert)
self.register_forward_hook(self.output_convert)
self.register_forward_hook(self.ddp_reduce)
self.register_forward_hook(self.compute)
# initialize state
self._reductions = {}
self._defaults = {}
@staticmethod
def input_convert(self, data: Any):
def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
"""
Implement how the inputs should be casted before calling forward
Adds metric state variable. Only used by subclasses.
Args:
data: input to forward method
name: The name of the state variable. The variable will then be accessible at ``self.name``.
default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
reset to this value when ``self.reset()`` is called.
dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
function in this parameter.
Note:
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
However, there won't be any reduction function applied to the synchronized metric state.
The metric states would be synced as follows
- If the metric state is a ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across
the process dimension. The original ``torch.Tensor`` metric state retains its dimensions,
and hence the synchronized output will be of shape ``(num_process, ...)``.
- If the metric state is a ``list``, the synced value will be a ``list`` containing the
combined elements from all processes.
Note:
When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
the format discussed in the above note.
Returns:
casted data
"""
return data
if not isinstance(default, torch.Tensor) and not (isinstance(default, list) and len(default) == 0):
raise ValueError(
"state variable must be a tensor or an empty list (where you can append tensors)"
)
if dist_reduce_fx == "sum":
dist_reduce_fx = dim_zero_sum
elif dist_reduce_fx == "mean":
dist_reduce_fx = dim_zero_mean
elif dist_reduce_fx == "cat":
dist_reduce_fx = dim_zero_cat
elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable):
raise ValueError(
"`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
)
if isinstance(default, torch.Tensor):
self.register_buffer(name, default)
else:
setattr(self, name, default)
self._defaults[name] = deepcopy(default)
self._reductions[name] = dist_reduce_fx
@abstractmethod
def forward(self, *args, **kwargs):
"""
Implements the actual metric computation.
Returns:
metric value or metric state
Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
"""
raise NotImplementedError
# add current step
self.update(*args, **kwargs)
@staticmethod
def output_convert(self, data: Any, output: Any):
if self.compute_on_step:
self._to_sync = self.ddp_sync_on_step
# save context before switch
self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
# call reset, update, compute, on single batch
self.reset()
self.update(*args, **kwargs)
result = self.compute()
# restore context
for attr, val in self._cache.items():
setattr(self, attr, val)
self._to_sync = True
self._computed = None
return result
def _sync_dist(self):
input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}
output_dict = apply_to_collection(
input_dict,
torch.Tensor,
gather_all_tensors_if_available,
group=self.process_group,
)
for attr, reduction_fn in self._reductions.items():
# pre-processing ops (stack or flatten for inputs)
if isinstance(output_dict[attr][0], torch.Tensor):
output_dict[attr] = torch.stack(output_dict[attr])
elif isinstance(output_dict[attr][0], list):
output_dict[attr] = _flatten(output_dict[attr])
assert reduction_fn is None or callable(reduction_fn)
reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
setattr(self, attr, reduced)
def _wrap_update(self, update):
@functools.wraps(update)
def wrapped_func(*args, **kwargs):
self._computed = None
return update(*args, **kwargs)
return wrapped_func
def _wrap_compute(self, compute):
@functools.wraps(compute)
def wrapped_func(*args, **kwargs):
# return cached value
if self._computed is not None:
return self._computed
if (
self._to_sync
and torch.distributed.is_available() # noqa: W503
and torch.distributed.is_initialized() # noqa: W503
):
self._sync_dist()
self._computed = compute(*args, **kwargs)
self.reset()
return self._computed
return wrapped_func
@abstractmethod
def update(self) -> None: # pylint: disable=E0202
"""
Implement how outputs from forward should be casted
Args:
data: input to forward method
output: output from forward method
Returns:
casted outputs
Override this method to update the state variables of your metric class.
"""
return apply_to_collection(output, (torch.Tensor, np.ndarray), at_least_1d)
pass
def ddp_sync(self, tensor: Any):
@abstractmethod
def compute(self): # pylint: disable=E0202
"""
Implement how the outputs from forward should be synced
(per default just gathers all of them and adds them to self._step_vals)
Args:
tensor: tensor to sync
Returns:
synced output
Override this method to compute the final metric value from state variables
synchronized across the distributed backend.
"""
gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group)
return gathered_tensors
@staticmethod
def ddp_reduce(self, data: Any, output: Any):
"""
Implement how the outputs from forward should be synced and reduced across nodes
Args:
data: input to forward method
output: output from the `output_convert` hook
Returns:
synced output
"""
synced = self.ddp_sync(output)
agg_val = self.aggregate(synced)
self._step_vals.append(agg_val)
return agg_val
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
"""
Implement aggregation of values on the same device
Args:
tensors: the values to be aggregated
Returns:
aggregated values
"""
# single tensor
if len(tensors) == 1:
tensors = tensors[0]
if isinstance(tensors, Mapping):
return {k: _stack_and_agg(tensors[k]) for k in tensors.keys()}
if isinstance(tensors, list):
return _stack_and_agg(tensors)
if isinstance(tensors, tuple):
return tensors
if isinstance(tensors, torch.Tensor):
return _stack_and_agg(tensors)
# multiple tensors (from aggregation over batches)
if isinstance(tensors[0], Mapping):
return {k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys()}
if isinstance(tensors[0], Sequence):
return tuple([torch.stack(tmp).sum(0) for tmp in zip(*tensors)])
if isinstance(tensors[0], torch.Tensor):
return torch.stack(tensors).sum(0)
raise TypeError("unknown metric value format to aggregate")
@staticmethod
def compute(self, data: Any, output: Any):
"""
Implement additional metric computations to be done after the aggregation
Args:
data: input to forward method
output: output from the `aggregate` hook
Returns:
final metric value
"""
return output
@property
def aggregated(self) -> torch.Tensor:
aggr = self.aggregate(*self._step_vals if len(self._step_vals) > 1 else self._step_vals)
self.reset()
return self.compute(self, None, aggr)
pass
def reset(self):
self._step_vals = []
def _stack_and_agg(tensors):
""" Utility function for stacking and aggregating tensors """
if isinstance(tensors, list):
return torch.sum(torch.stack([t for t in tensors]), 0)
return tensors.squeeze() if tensors.numel() == 1 else tensors
class TensorMetric(Metric):
"""
Base class for metric implementation operating directly on tensors.
All inputs and outputs will be casted to tensors if necessary.
Already handles DDP sync and input/output conversions.
"""
@staticmethod
def input_convert(self, data: Any):
data = apply_to_collection(
data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
)
return super(TensorMetric, self).input_convert(self, data)
@staticmethod
def output_convert(self, data: Any, output: Any):
output = apply_to_collection(
output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
)
return super(TensorMetric, self).output_convert(self, data, output)
class NumpyMetric(Metric):
"""
Base class for metric implementation operating on numpy arrays.
All inputs will be casted to numpy if necessary and all outputs will
be casted to tensors if necessary.
Already handles DDP sync and input/output conversions.
"""
@staticmethod
def input_convert(self, data: Any):
data = apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
return super(NumpyMetric, self).input_convert(self, data)
@staticmethod
def output_convert(self, data: Any, output: Any):
output = apply_to_collection(
output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
)
return super(NumpyMetric, self).output_convert(self, data, output)
"""
This method automatically resets the metric state variables to their default value.
"""
for attr, default in self._defaults.items():
current_val = getattr(self, attr)
if isinstance(current_val, torch.Tensor):
setattr(self, attr, deepcopy(default).to(current_val.device))
else:
setattr(self, attr, deepcopy(default))
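To make the subclassing contract concrete, here is a minimal sketch of a user-defined metric built on this base class; the class name is invented for illustration and mirrors the pattern of the regression metrics added further below:

import torch
from pytorch_lightning.metrics import Metric


class MyMeanAbsoluteError(Metric):
    def __init__(self, compute_on_step: bool = True):
        super().__init__(compute_on_step=compute_on_step)
        # tensor states are registered as buffers; "sum" reduces them across processes
        self.add_state("abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        self.abs_error += torch.sum(torch.abs(preds - target))
        self.total += target.numel()

    def compute(self):
        return self.abs_error / self.total

Calling the metric object runs ``update()`` and, with ``compute_on_step=True``, also returns the value for the current batch, while ``compute()`` returns the value accumulated over all batches seen since the last reset.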

@@ -1,60 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from pytorch_lightning.metrics.functional.nlp import bleu_score
from pytorch_lightning.metrics.metric import Metric
class BLEUScore(Metric):
"""
Calculate BLEU score of machine translated text with one or more references.
Example:
>>> translate_corpus = ['the cat is on the mat'.split()]
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
>>> metric = BLEUScore()
>>> metric(translate_corpus, reference_corpus)
tensor(0.7598)
"""
def __init__(self, n_gram: int = 4, smooth: bool = False):
"""
Args:
n_gram: Gram value ranged from 1 to 4 (Default 4)
smooth: Whether or not to apply smoothing Lin et al. 2004
"""
super().__init__(name="bleu")
self.n_gram = n_gram
self.smooth = smooth
def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
"""
Actual metric computation
Args:
translate_corpus: An iterable of machine translated corpus
reference_corpus: An iterable of iterables of reference corpus
Return:
torch.Tensor: BLEU Score
"""
return bleu_score(
translate_corpus=translate_corpus,
reference_corpus=reference_corpus,
n_gram=self.n_gram,
smooth=self.smooth,
).to(self.device, self.dtype)

@@ -1,361 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Any
import torch
from pytorch_lightning.metrics.functional.regression import (
mae,
mse,
psnr,
rmse,
rmsle,
ssim
)
from pytorch_lightning.metrics.metric import Metric
class MSE(Metric):
"""
Computes the mean squared loss.
Example:
>>> pred = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 1, 2, 2])
>>> metric = MSE()
>>> metric(pred, target)
tensor(0.2500)
"""
def __init__(
self,
reduction: str = 'elementwise_mean',
):
"""
Args:
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""
super().__init__(name='mse')
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the mse loss.
"""
return mse(pred, target, return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
sse, n = output['squared_error'], output['n_observations']
return sse / n
class RMSE(Metric):
"""
Computes the root mean squared loss.
Example:
>>> pred = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 1, 2, 2])
>>> metric = RMSE()
>>> metric(pred, target)
tensor(0.5000)
"""
def __init__(
self,
reduction: str = 'elementwise_mean',
):
"""
Args:
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""
super().__init__(name='rmse')
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the rmse loss.
"""
return rmse(pred, target, reduction='none', return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
""" Squaring needs to happend after ddp sync """
sse, n = output['squared_error'], output['n_observations']
return torch.sqrt(sse / n)
class MAE(Metric):
"""
Computes the mean absolute loss or L1-loss.
Example:
>>> pred = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 1, 2, 2])
>>> metric = MAE()
>>> metric(pred, target)
tensor(0.2500)
"""
def __init__(
self,
reduction: str = 'elementwise_mean',
):
"""
Args:
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""
super().__init__(name='mae')
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the mae loss.
"""
return mae(pred, target, return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
sae, n = output['absolute_error'], output['n_observations']
return sae / n
class RMSLE(Metric):
"""
Computes the root mean squared log loss.
Example:
>>> pred = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 1, 2, 2])
>>> metric = RMSLE()
>>> metric(pred, target)
tensor(0.1438)
"""
def __init__(
self,
reduction: str = 'elementwise_mean',
):
"""
Args:
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""
super().__init__(name='rmsle')
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with the rmsle loss.
"""
return mse(torch.log(pred + 1), torch.log(target + 1),
self.reduction, return_state=True)
@staticmethod
def compute(self, data: Any, output: Any):
""" Squaring needs to happend after ddp sync """
sse, n = output['squared_error'], output['n_observations']
return torch.sqrt(sse / n)
class PSNR(Metric):
"""
Computes the peak signal-to-noise ratio
Example:
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
>>> metric = PSNR()
>>> metric(pred, target)
tensor(2.5527)
"""
def __init__(
self,
data_range: float = None,
base: int = 10,
reduction: str = 'elementwise_mean'
):
"""
Args:
data_range: the range of the data. If None, it is determined from the data (max - min)
base: a base of a logarithm to use (default: 10)
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
"""
super().__init__(name='psnr')
self.data_range = data_range
self.base = float(base)
self.reduction = reduction
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: predicted labels
target: ground truth labels
Return:
A Tensor with psnr score.
"""
return psnr(pred, target, self.data_range, self.base, self.reduction, return_state=True)
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
""" Special aggregation function as the data range needs to be correctly synced """
if len(tensors) == 1:
tensors = tensors[0]
output = {'data_range': torch.stack([t for t in tensors['data_range']]).max()}
output.update({k: torch.stack([t for t in tensors[k]]).sum(0) for k in tensors.keys() if k != 'data_range'})
return output
output = {'data_range': torch.stack([tensor['data_range'] for tensor in tensors]).max()}
output.update({k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys() if k != 'data_range'})
return output
@staticmethod
def compute(self, data: Any, output: Any):
"""
Compute final value based on the synced data_range, sum of squared errors
and number of samples.
Args:
data: input to forward method
output: output from the `aggregate` hook
Returns:
final metric value
"""
sse, n, data_range = output['sum_squared_error'], output['n_obs'], output['data_range']
psnr_base_e = 2 * torch.log(data_range) - torch.log(sse / n)
psnr = psnr_base_e * (10 / torch.log(torch.tensor(self.base)))
return psnr
class SSIM(Metric):
"""
Computes Structural Similarity Index Measure
Example:
>>> pred = torch.rand([16, 1, 16, 16])
>>> target = pred * 0.75
>>> metric = SSIM()
>>> metric(pred, target)
tensor(0.9219)
"""
def __init__(
self,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: str = "elementwise_mean",
data_range: float = None,
k1: float = 0.01,
k2: float = 0.03
):
"""
Args:
kernel_size: Size of the gaussian kernel (default: (11, 11))
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
reduction: a method to reduce metric score over labels.
- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
k1: Parameter of SSIM. Default: 0.01
k2: Parameter of SSIM. Default: 0.03
"""
super().__init__(name="ssim")
self.kernel_size = kernel_size
self.sigma = sigma
self.reduction = reduction
self.data_range = data_range
self.k1 = k1
self.k2 = k2
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
pred: Estimated image
target: Ground truth image
Return:
A Tensor with SSIM score.
"""
return ssim(pred, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2)

@@ -0,0 +1,3 @@
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError

@@ -0,0 +1,54 @@
import torch
from typing import Any, Callable, Optional, Union
from pytorch_lightning.metrics.metric import Metric
class MeanAbsoluteError(Metric):
"""
Computes mean absolute error.
Example:
>>> from pytorch_lightning.metrics import MeanAbsoluteError
>>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
>>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
>>> mean_absolute_error = MeanAbsoluteError()
>>> mean_absolute_error(preds, target)
tensor(0.5000)
"""
def __init__(
self,
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)
self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape
abs_error = torch.abs(preds - target)
self.sum_abs_error += torch.sum(abs_error)
self.total += target.numel()
def compute(self):
"""
Computes mean absolute error over state.
"""
return self.sum_abs_error / self.total

@@ -0,0 +1,55 @@
import torch
from typing import Any, Callable, Optional, Union
from pytorch_lightning.metrics.metric import Metric
class MeanSquaredError(Metric):
"""
Computes mean squared error.
Example:
>>> from pytorch_lightning.metrics import MeanSquaredError
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
>>> mean_squared_error = MeanSquaredError()
>>> mean_squared_error(preds, target)
tensor(0.8750)
"""
def __init__(
self,
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)
self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape
squared_error = torch.pow(preds - target, 2)
self.sum_squared_error += torch.sum(squared_error)
self.total += target.numel()
def compute(self):
"""
Computes mean squared error over state.
"""
return self.sum_squared_error / self.total

@@ -0,0 +1,55 @@
import torch
from typing import Any, Callable, Optional, Union
from pytorch_lightning.metrics.metric import Metric
class MeanSquaredLogError(Metric):
"""
Computes mean squared logarithmic error.
Example:
>>> from pytorch_lightning.metrics import MeanSquaredLogError
>>> target = torch.tensor([2.5, 5, 4, 8])
>>> preds = torch.tensor([3, 5, 2.5, 7])
>>> mean_squared_log_error = MeanSquaredLogError()
>>> mean_squared_log_error(preds, target)
tensor(0.0397)
"""
def __init__(
self,
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)
self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape
squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2)
self.sum_squared_log_error += torch.sum(squared_log_error)
self.total += target.numel()
def compute(self):
"""
Compute mean squared logarithmic error over state.
"""
return self.sum_squared_log_error / self.total
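The three regression metrics above all follow the same accumulate-then-compute pattern; a short usage sketch over a few random batches, purely for illustration:

import torch
from pytorch_lightning.metrics import MeanSquaredError

mean_squared_error = MeanSquaredError()
for _ in range(4):                           # e.g. four validation batches
    preds = torch.rand(8)
    target = torch.rand(8)
    mean_squared_error(preds, target)        # returns the per-batch value and accumulates state
epoch_mse = mean_squared_error.compute()     # value over all 32 samples; state is then reset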

@@ -1,85 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity
from pytorch_lightning.metrics.metric import TensorMetric
from pytorch_lightning.utilities import rank_zero_warn
class EmbeddingSimilarity(TensorMetric):
"""
Computes similarity between embeddings
Example:
>>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
>>> embedding_similarity(embeddings)
tensor([[0.0000, 1.0000, 0.9759],
[1.0000, 0.0000, 0.9759],
[0.9759, 0.9759, 0.0000]])
"""
def __init__(
self,
similarity: str = 'cosine',
zero_diagonal: bool = True,
reduction: str = 'mean',
reduce_group: Any = None
):
"""
Args:
similarity: 'dot' or 'cosine'
reduction: 'none', 'sum', 'mean' (all along dim -1)
zero_diagonal: if True, the diagonals are set to zero
reduce_group: the process group to reduce metric results from DDP
"""
super().__init__(name='embedding_similarity',
reduce_group=reduce_group)
assert similarity in ('dot', 'cosine')
self.similarity = similarity
assert isinstance(zero_diagonal, bool)
self.zero_diagonal = zero_diagonal
assert reduction in ('none', 'sum', 'mean')
self.reduction = reduction
rank_zero_warn('Please note that Metric `EmbeddingSimilarity` does not support aggregation.')
def forward(self, batch: torch.Tensor) -> torch.Tensor:
"""
Actual metric computation
Args:
batch: tensor containing embeddings with shape (batch_size, dim)
Return:
A square matrix (batch, batch) with the similarity scores between all elements
If sum or mean are used, then returns (b, 1) with the reduced value for each row
"""
return embedding_similarity(batch,
similarity=self.similarity,
zero_diagonal=self.zero_diagonal,
reduction=self.reduction)
@staticmethod
def ddp_reduce(self, data: Any, output: Any):
""" reduction for this metric does not make sense """
return output
@property
def aggregated(self):
raise ValueError('Metric `EmbeddingSimilarity` does not support aggregation.')

File diff suppressed because it is too large
@@ -0,0 +1,19 @@
import torch
from typing import Any, Callable, Optional, Union
def dim_zero_cat(x):
return torch.cat(x, dim=0)
def dim_zero_sum(x):
return torch.sum(x, dim=0)
def dim_zero_mean(x):
return torch.mean(x, dim=0)
def _flatten(x):
return [item for sublist in x for item in sublist]
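These helpers are applied to the synced states inside ``Metric._sync_dist``; a tiny illustration, assuming two processes' worth of a one-dimensional state:

import torch

states = [torch.tensor([1., 2.]), torch.tensor([3., 4.])]  # one tensor per process
dim_zero_cat(states)                # tensor([1., 2., 3., 4.])
dim_zero_sum(torch.stack(states))   # tensor([4., 6.])
dim_zero_mean(torch.stack(states))  # tensor([2., 3.])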

@@ -16,7 +16,15 @@ import os
import warnings
from functools import wraps
import torch
from pytorch_lightning import _logger as log
from typing import Union, Optional, Any
if torch.distributed.is_available():
from torch.distributed import ReduceOp
else:
class ReduceOp:
SUM = None
def rank_zero_only(fn):
@@ -63,3 +71,71 @@ def find_free_network_port() -> int:
port = s.getsockname()[1]
s.close()
return port
def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes
Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)
Return:
gathered_result: list with size equal to the process group where
gathered_result[i] corresponds to result tensor from process i
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
if group is None:
group = torch.distributed.group.WORLD
world_size = torch.distributed.get_world_size(group)
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
# sync and broadcast all
torch.distributed.barrier(group=group)
torch.distributed.all_gather(gathered_result, result, group)
result = gathered_result
return result
def sync_ddp_if_available(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process
Args:
result: the value to sync and reduce (typically tensor or number)
group: the process group to gather results from. Defaults to all processes (world)
reduce_op: the reduction operation. Defaults to sum.
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
Return:
reduced value
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
divide_by_world_size = False
if group is None:
group = torch.distributed.group.WORLD
if reduce_op is None:
reduce_op = torch.distributed.ReduceOp.SUM
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
reduce_op = torch.distributed.ReduceOp.SUM
divide_by_world_size = True
# sync all processes before reduction
torch.distributed.barrier(group=group)
torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)
return result
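Both helpers fall back to returning their input on a single process, so a usage sketch like the following is safe to run outside DDP (names and values are illustrative):

import torch
from pytorch_lightning.utilities.distributed import (
    gather_all_tensors_if_available,
    sync_ddp_if_available,
)

local_value = torch.tensor(1.0)  # e.g. a per-process statistic
# averaged across processes when a DDP group is initialized, returned unchanged otherwise
mean_value = sync_ddp_if_available(local_value.clone(), reduce_op="mean")
# list of per-process tensors under DDP, the original tensor otherwise
all_values = gather_all_tensors_if_available(local_value)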

@@ -0,0 +1,5 @@
import os
from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.test_metric import Dummy

@@ -0,0 +1,155 @@
import os
import pytest
import torch
import numpy as np
from collections import namedtuple
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from sklearn.metrics import accuracy_score
from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
torch.manual_seed(42)
# global vars
num_classes = 5
threshold = 0.5
extra_dim = 3
Input = namedtuple('Input', ["preds", "target"])
def test_accuracy_invalid_shape():
with pytest.raises(ValueError):
acc = Accuracy()
acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3))
_binary_prob_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
)
def _binary_prob_sk_metric(preds, target):
sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8)
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
_binary_inputs = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,))
)
def _binary_sk_metric(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
_multilabel_prob_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes))
)
def _multilabel_prob_sk_metric(preds, target):
sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8)
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
_multilabel_inputs = Input(
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)),
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes))
)
def _multilabel_sk_metric(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
_multiclass_prob_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes),
target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE))
)
def _multiclass_prob_sk_metric(preds, target):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
_multiclass_inputs = Input(
preds=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)),
target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE))
)
def _multiclass_sk_metric(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
_multidim_multiclass_prob_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, extra_dim),
target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE, extra_dim))
)
def _multidim_multiclass_prob_sk_metric(preds, target):
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
_multidim_multiclass_inputs = Input(
preds=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)),
target=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE))
)
def _multidim_multiclass_sk_metric(preds, target):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
@pytest.mark.parametrize("preds, target, sk_metric", [
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric),
(_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric),
(_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric),
(_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_multidim_multiclass_prob_sk_metric
),
(
_multidim_multiclass_inputs.preds,
_multidim_multiclass_inputs.target,
_multidim_multiclass_sk_metric
)
])
def test_accuracy(ddp, ddp_sync_on_step, preds, target, sk_metric):
compute_batch(preds, target, Accuracy, sk_metric, ddp_sync_on_step, ddp, metric_args={"threshold": threshold})

@@ -1,485 +0,0 @@
from functools import partial
import pytest
import torch
from sklearn.metrics import (
accuracy_score as sk_accuracy,
jaccard_score as sk_jaccard_score,
precision_score as sk_precision,
recall_score as sk_recall,
f1_score as sk_f1_score,
fbeta_score as sk_fbeta_score,
confusion_matrix as sk_confusion_matrix,
roc_curve as sk_roc_curve,
roc_auc_score as sk_roc_auc_score,
precision_recall_curve as sk_precision_recall_curve
)
from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional.classification import (
to_onehot,
to_categorical,
get_num_classes,
stat_scores,
stat_scores_multiple_classes,
accuracy,
confusion_matrix,
precision,
recall,
fbeta_score,
f1_score,
_binary_clf_curve,
dice_score,
average_precision,
auroc,
precision_recall_curve,
roc,
auc,
iou,
)
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [
pytest.param(sk_accuracy, accuracy, False, id='accuracy'),
pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'),
pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'),
pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'),
pytest.param(partial(sk_f1_score, average='micro'), f1_score, False, id='f1_score'),
pytest.param(partial(sk_fbeta_score, average='micro', beta=2),
partial(fbeta_score, beta=2), False, id='fbeta_score'),
pytest.param(sk_confusion_matrix, confusion_matrix, False, id='confusion_matrix'),
pytest.param(sk_roc_curve, roc, True, id='roc'),
pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'),
pytest.param(sk_roc_auc_score, auroc, True, id='auroc')
])
def test_against_sklearn(sklearn_metric, torch_metric, only_binary):
"""Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# for metrics with only_binary=False, we try out different combinations of number
# of labels in pred and target (also test binary)
# for metrics with only_binary=True, target is always binary and pred will be
# (unnormalized) class probabilities
class_comb = [(5, 2)] if only_binary else [(10, 10), (5, 10), (10, 5), (2, 2)]
for n_cls_pred, n_cls_target in class_comb:
pred = torch.randint(n_cls_pred, (300,), device=device)
target = torch.randint(n_cls_target, (300,), device=device)
sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy())
pl_score = torch_metric(pred, target)
# if multi output
if isinstance(sk_score, tuple):
sk_score = [torch.tensor(sk_s.copy(), dtype=torch.float, device=device) for sk_s in sk_score]
for sk_s, pl_s in zip(sk_score, pl_score):
assert torch.allclose(sk_s, pl_s.float())
else:
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
assert torch.allclose(sk_score, pl_score)
@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted'])
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(sk_precision, precision, id='precision'),
pytest.param(sk_recall, recall, id='recall'),
pytest.param(sk_f1_score, f1_score, id='f1_score'),
pytest.param(partial(sk_fbeta_score, beta=2), partial(fbeta_score, beta=2), id='fbeta_score')
])
def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric):
""" Test metrics where the class_reduction parameter have a correponding
value in sklearn """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pred = torch.randint(10, (300,), device=device)
target = torch.randint(10, (300,), device=device)
sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
average=class_reduction)
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target, class_reduction=class_reduction)
assert torch.allclose(sk_score, pl_score)
def test_onehot():
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
expected = torch.stack([
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
])
assert test_tensor.shape == (2, 5)
assert expected.shape == (2, 10, 5)
onehot_classes = to_onehot(test_tensor, num_classes=10)
onehot_no_classes = to_onehot(test_tensor)
assert torch.allclose(onehot_classes, onehot_no_classes)
assert onehot_classes.shape == expected.shape
assert onehot_no_classes.shape == expected.shape
assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
assert torch.allclose(expected.to(onehot_classes), onehot_classes)
def test_to_categorical():
test_tensor = torch.stack([
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
]).to(torch.float)
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
assert expected.shape == (2, 5)
assert test_tensor.shape == (2, 10, 5)
result = to_categorical(test_tensor)
assert result.shape == expected.shape
assert torch.allclose(result, expected.to(result.dtype))
@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
])
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
assert get_num_classes(pred, target, num_classes) == expected_num_classes
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
'expected_tn', 'expected_fn', 'expected_support'], [
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2)
])
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)
assert tp.item() == expected_tp
assert fp.item() == expected_fp
assert tn.item() == expected_tn
assert fn.item() == expected_fn
assert sup.item() == expected_support
@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp',
'expected_tn', 'expected_fn', 'expected_support'], [
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none',
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none',
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum',
torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)),
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
])
def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
assert torch.allclose(torch.tensor(expected_support).to(sup), sup)
def test_multilabel_accuracy():
# Dense label indicator matrix format
y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.]))
assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.]))
assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.]))
assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.]))
assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.]))
# when num_classes does not match the number of classes extracted from the input, we expect a warning
with pytest.warns(RuntimeWarning,
match=r'You have set .* number of classes which is'
r' different from predicted (.*) and'
r' target (.*) number of classes'):
_ = accuracy(y2, torch.zeros_like(y2), num_classes=3)
def test_accuracy():
pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 2, 2])
acc = accuracy(pred, target)
assert acc.item() == 0.75
pred = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 3])
acc = accuracy(pred, target)
assert acc.item() == 0.50
def test_confusion_matrix():
target = (torch.arange(120) % 3).view(-1, 1)
pred = target.clone()
cm = confusion_matrix(pred, target, normalize=True)
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))
pred = torch.zeros_like(pred)
cm = confusion_matrix(pred, target, normalize=True)
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))
target = torch.LongTensor([0, 0, 0, 0, 0])
pred = target.clone()
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]]))
# Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9)
pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15)
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]]))
to_compare = cm / torch.tensor([[13.], [16.], [9.]])
cm = confusion_matrix(pred, target, normalize=True, num_classes=3)
assert torch.allclose(cm, to_compare)
@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
])
def test_precision_recall(pred, target, expected_prec, expected_rec):
prec = precision(pred, target, class_reduction='none')
rec = recall(pred, target, class_reduction='none')
assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)
@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
])
def test_fbeta_score(pred, target, beta, exp_score):
score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]),
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
])
def test_f1_score(pred, target, exp_score):
score = f1_score(torch.tensor(pred), torch.tensor(target), class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), class_reduction='none')
assert torch.allclose(score, torch.tensor(exp_score))
@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [
pytest.param(1, 1., 42),
pytest.param(None, 1., 42),
])
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
# TODO: move pred and target back to the test function arguments
# if you fix the array inside the function, you'd also have to fix the shape,
# because whenever the array changes, the expected shape has to change with it
seed_everything(0)
pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
target = torch.tensor([0, 1] * 50, dtype=torch.int)
if sample_weight is not None:
sample_weight = torch.ones_like(pred) * sample_weight
fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
assert isinstance(tps, torch.Tensor)
assert isinstance(fps, torch.Tensor)
assert isinstance(thresh, torch.Tensor)
assert tps.shape == (exp_shape,)
assert fps.shape == (exp_shape,)
assert thresh.shape == (exp_shape,)
@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
])
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
assert p.size() == r.size()
assert p.size(0) == t.size(0) + 1
assert torch.allclose(p, torch.tensor(expected_p).to(p))
assert torch.allclose(r, torch.tensor(expected_r).to(r))
assert torch.allclose(t, torch.tensor(expected_t).to(t))
@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
])
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
assert fpr.shape == tpr.shape
assert fpr.size(0) == thresh.size(0)
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
])
def test_auroc(pred, target, expected):
score = auroc(torch.tensor(pred), torch.tensor(target)).item()
assert score == expected
@pytest.mark.parametrize(['x', 'y', 'expected'], [
pytest.param([0, 1], [0, 1], 0.5),
pytest.param([1, 0], [0, 1], 0.5),
pytest.param([1, 0, 0], [0, 1, 1], 0.5),
pytest.param([0, 1], [1, 1], 1),
pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
])
def test_auc(x, y, expected):
# Test Area Under Curve (AUC) computation; the expected values are worked through after this test
assert auc(torch.tensor(x), torch.tensor(y)) == expected
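
A quick sketch of where the expected values above come from (assuming auc integrates y over x with the trapezoidal rule, which matches every case in the table):

# x=[0, 1], y=[0, 1]           -> area of a triangle    = 0.5
# x=[0, 1], y=[1, 1]           -> area of a unit square = 1.0
# x=[0, 0.5, 1], y=[0, 0.5, 1] -> the same triangle     = 0.5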
@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
# Check the average_precision_score of a constant predictor is
# the TPR
# Generate a dataset with 25% of positives
# And a constant score
# The precision is then the fraction of positive whatever the recall
# is, as there is only one threshold:
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
# With threshold 0.8: 1 TP, 2 TN and one FN (the 0.75 value is worked through after this test)
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
])
def test_average_precision(scores, target, expected_score):
assert average_precision(scores, target) == expected_score
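
Worked values for the second parametrization above (a sketch, using the step-wise definition AP = sum_n (R_n - R_{n-1}) * P_n):

# scores sorted descending: 9 (pos), .8 (neg), .7 (neg), .6 (pos)
# operating points (P, R):  (1, 0.5), (1/2, 0.5), (1/3, 0.5), (1/2, 1.0)
# AP = 0.5 * 1 + 0 * 1/2 + 0 * 1/3 + 0.5 * 1/2 = 0.75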
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),
pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
])
def test_dice_score(pred, target, expected):
score = dice_score(torch.tensor(pred), torch.tensor(target))
assert score == expected
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
iou_val = iou(
pred=pred,
target=target,
ignore_index=ignore_index,
reduction=reduction,
)
assert torch.allclose(iou_val, expected, atol=1e-9)
@pytest.mark.parametrize('metric', [auroc])
def test_error_on_multiclass_input(metric):
""" check that these metrics raise an error if they are used for multiclass problems """
pred = torch.randint(0, 10, (100, ))
target = torch.randint(0, 10, (100, ))
with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"):
_ = metric(pred, target)
# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
# `absent_score`.
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is absent.
pytest.param([0], [0], None, -1., 2, [1., -1.]),
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
# absent_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], None, -1., 1, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is absent.
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
# When 0 index ignored, class 0 does not get a score (not even the absent_score).
pytest.param([1], [1], 0, -1., 2, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
# 2 is absent.
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
# 2 is absent.
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
# Sanity checks with absent_score of 1.0.
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
])
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
absent_score=absent_score,
num_classes=num_classes,
reduction='none',
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
])
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
ignore_index=ignore_index,
num_classes=num_classes,
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))

@ -1,66 +0,0 @@
import pytest
import torch
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
from pytorch_lightning.metrics.functional.nlp import bleu_score
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
HYPOTHESIS1 = tuple(
"It is a guide to action which ensures that the military always obeys the commands of the party".split()
)
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
REFERENCE2 = tuple(
"It is a guiding principle which makes the military forces always being under the command of the Party".split()
)
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
smooth_func = SmoothingFunction().method2
@pytest.mark.parametrize(
["weights", "n_gram", "smooth_func", "smooth"],
[
pytest.param([1], 1, None, False),
pytest.param([0.5, 0.5], 2, smooth_func, True),
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
],
)
def test_bleu_score(weights, n_gram, smooth_func, smooth):
nltk_output = sentence_bleu(
[REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
)
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
assert torch.allclose(pl_output, torch.tensor(nltk_output))
def test_bleu_empty():
hyp = [[]]
ref = [[[]]]
assert bleu_score(hyp, ref) == torch.tensor(0.0)
def test_no_4_gram():
hyps = [["My", "full", "pytorch-lightning"]]
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
assert bleu_score(hyps, refs) == torch.tensor(0.0)

@ -1,30 +0,0 @@
import pytest
import torch
from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce
def test_reduce():
start_tensor = torch.rand(50, 40, 30)
assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)
with pytest.raises(ValueError):
reduce(start_tensor, 'error_reduction')
def test_class_reduce():
num = torch.randint(1, 10, (100,)).float()
denom = torch.randint(10, 20, (100,)).float()
weights = torch.randint(1, 100, (100,)).float()
assert torch.allclose(class_reduce(num, denom, weights, 'micro'),
torch.sum(num) / torch.sum(denom))
assert torch.allclose(class_reduce(num, denom, weights, 'macro'),
torch.mean(num / denom))
assert torch.allclose(class_reduce(num, denom, weights, 'weighted'),
torch.sum(num / denom * (weights / torch.sum(weights))))
assert torch.allclose(class_reduce(num, denom, weights, 'none'),
num / denom)

@ -1,175 +0,0 @@
import numpy as np
import pytest
import torch
from functools import partial
from math import sqrt
from skimage.metrics import (
peak_signal_noise_ratio as ski_psnr,
structural_similarity as ski_ssim
)
from sklearn.metrics import (
mean_absolute_error as mae_sk,
mean_squared_error as mse_sk,
mean_squared_log_error as msle_sk
)
from pytorch_lightning.metrics.functional import (
mae,
mse,
psnr,
rmse,
rmsle,
ssim
)
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(mae_sk, mae, id='mean_absolute_error'),
pytest.param(mse_sk, mse, id='mean_squared_error'),
pytest.param(partial(mse_sk, squared=False), rmse, id='root_mean_squared_error'),
pytest.param(lambda x, y: sqrt(msle_sk(x, y)), rmsle, id='root_mean_squared_log_error')
])
def test_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version."""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# iterate over different label counts in predictions and target
pred = torch.rand(300, device=device)
target = torch.rand(300, device=device)
sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy())
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target)
assert torch.allclose(sk_score, pl_score)
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0),
])
def test_mse(pred, target, expected):
score = mse(torch.tensor(pred), torch.tensor(target))
assert score.item() == expected
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.5),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321),
])
def test_rmse(pred, target, expected):
score = rmse(torch.tensor(pred), torch.tensor(target))
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5),
])
def test_mae(pred, target, expected):
score = mae(torch.tensor(pred), torch.tensor(target))
assert score.item() == expected
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.1438),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.5330),
])
def test_rmsle(pred, target, expected):
score = rmsle(torch.tensor(pred), torch.tensor(target))
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)
@pytest.mark.parametrize(['pred', 'target'], [
pytest.param([0., 1., 2., 3.], [0., 1., 2., 3.]),
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
])
def test_psnr_with_skimage(pred, target):
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target), data_range=3)
sk_score = ski_psnr(np.array(pred), np.array(target), data_range=3)
assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float), atol=1e-3)
@pytest.mark.parametrize(['pred', 'target'], [
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
])
def test_psnr_base_e_wider_range(pred, target):
score = psnr(pred=torch.tensor(pred),
target=torch.tensor(target),
data_range=4,
base=2.718281828459045)
sk_score = ski_psnr(np.array(pred), np.array(target), data_range=4) * np.log(10)
assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float32), atol=1e-3)
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio')
])
def test_psnr_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version."""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
pred = torch.randint(n_cls_pred, (500,), device=device, dtype=torch.float)
target = torch.randint(n_cls_target, (500,), device=device, dtype=torch.float)
sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
data_range=n_cls_target)
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target, data_range=n_cls_target)
assert torch.allclose(sk_score, pl_score)
@pytest.mark.parametrize(['size', 'channel', 'coef', 'multichannel'], [
pytest.param(16, 1, 0.9, False),
pytest.param(32, 3, 0.8, True),
pytest.param(48, 4, 0.7, True),
pytest.param(64, 5, 0.6, True)
])
def test_ssim(size, channel, coef, multichannel):
device = "cuda" if torch.cuda.is_available() else "cpu"
pred = torch.rand(size, channel, size, size, device=device)
target = pred * coef
ssim_idx = ssim(pred, target, data_range=1.0)
np_pred = pred.permute(0, 2, 3, 1).cpu().numpy()
if multichannel is False:
np_pred = np_pred[:, :, :, 0]
np_target = np.multiply(np_pred, coef)
sk_ssim_idx = ski_ssim(
np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True, data_range=1.0
)
assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-4)
ssim_idx = ssim(pred, pred)
assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # shape
pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
])
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
pred_t = torch.rand(pred)
target_t = torch.rand(target, dtype=torch.float64)
with pytest.raises(TypeError):
ssim(pred_t, target_t)
pred = torch.rand(pred)
target = torch.rand(target)
with pytest.raises(ValueError):
ssim(pred, target, kernel, sigma)

@ -1,35 +0,0 @@
import pytest
import torch
from sklearn.metrics import pairwise
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity
@pytest.mark.parametrize('similarity', ['cosine', 'dot'])
@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum'])
def test_against_sklearn(similarity, reduction):
"""Compare PL metrics to sklearn version."""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch = torch.randn(5, 10, device=device) # 100 samples in 10 dimensions
pl_dist = embedding_similarity(batch, similarity=similarity,
reduction=reduction, zero_diagonal=False)
def sklearn_embedding_distance(batch, similarity, reduction):
metric_func = {'cosine': pairwise.cosine_similarity,
'dot': pairwise.linear_kernel}[similarity]
dist = metric_func(batch, batch)
if reduction == 'mean':
return dist.mean(axis=-1)
if reduction == 'sum':
return dist.sum(axis=-1)
return dist
sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(),
similarity=similarity, reduction=reduction)
sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)
assert torch.allclose(sk_dist, pl_dist)

@ -0,0 +1,57 @@
import torch
import pytest
from collections import namedtuple
from functools import partial
from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_squared_log_error
from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
num_targets = 5
Input = namedtuple('Input', ["preds", "target"])
_single_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)
_multi_target_inputs = Input(
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
return sk_fn(sk_preds, sk_target)
def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()
return sk_fn(sk_preds, sk_target)
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
@pytest.mark.parametrize(
"preds, target, sk_metric",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
],
)
@pytest.mark.parametrize(
"metric_class, sk_fn",
[
(MeanSquaredError, mean_squared_error),
(MeanAbsoluteError, mean_absolute_error),
(MeanSquaredLogError, mean_squared_log_error),
],
)
def test_mean_error(ddp, ddp_sync_on_step, preds, target, sk_metric, metric_class, sk_fn):
compute_batch(preds, target, metric_class, partial(sk_metric, sk_fn=sk_fn), ddp_sync_on_step, ddp)
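
The compute_batch helper comes from tests/metrics/utils.py, which is not shown in this excerpt. A hypothetical sketch of its non-DDP path, inferred only from how it is called in these tests (the ddp_sync_on_step constructor argument and the exact comparison are assumptions):

import numpy as np

def _compute_batch_sketch(preds, target, metric_class, sk_metric, ddp_sync_on_step, metric_args=None):
    # hypothetical: metric classes are assumed to accept ddp_sync_on_step in __init__
    metric = metric_class(ddp_sync_on_step=ddp_sync_on_step, **(metric_args or {}))
    for batch_preds, batch_target in zip(preds, target):  # one update per batch along NUM_BATCHES
        metric(batch_preds, batch_target)
    # the sk_metric callables above flatten the full tensors themselves
    assert np.allclose(metric.compute().numpy(), sk_metric(preds, target))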

@ -1,297 +0,0 @@
import pytest
import sys
from collections import namedtuple
from functools import partial
import math
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import numpy as np
from tests.base import EvalModelTemplate
from pytorch_lightning import Trainer
import tests.base.develop_utils as tutils
from pytorch_lightning.metrics import (
Accuracy,
ConfusionMatrix,
PrecisionRecallCurve,
Precision,
Recall,
AveragePrecision,
AUROC,
FBeta,
F1,
ROC,
MulticlassROC,
MulticlassPrecisionRecallCurve,
DiceCoefficient,
IoU,
MAE,
MSE,
RMSE,
RMSLE,
PSNR,
SSIM,
)
from sklearn.metrics import (
accuracy_score,
confusion_matrix,
precision_recall_curve,
precision_score,
recall_score,
average_precision_score,
roc_auc_score,
fbeta_score,
f1_score,
roc_curve,
jaccard_score,
mean_squared_error,
mean_absolute_error,
mean_squared_log_error
)
from skimage.metrics import (
peak_signal_noise_ratio,
structural_similarity
)
# example structure
TestCase = namedtuple('example', ['name', 'lightning_metric', 'comparing_metric', 'test_input'])
# setup some standard testcases
NB_SAMPLES = 200
multiclass_example = [(torch.randint(10, (NB_SAMPLES,)), torch.randint(10, (NB_SAMPLES,)))]
binary_example = [(torch.randint(2, (NB_SAMPLES,)), torch.randint(2, (NB_SAMPLES,)))]
multiclass_and_binary_example = [*multiclass_example, *binary_example]
binary_example_logits = (torch.randint(2, (NB_SAMPLES,)), torch.randint(5, (NB_SAMPLES,)))
multiclass_example_probs = (torch.randint(10, (NB_SAMPLES,)), torch.randn((NB_SAMPLES, 10)).softmax(-1))
regression_example = [(torch.rand((NB_SAMPLES,)), torch.rand((NB_SAMPLES,)))]
# construct additional test functions
def root_mean_squared_error(x, y):
return math.sqrt(mean_squared_error(x, y))
def root_mean_squared_log_error(x, y):
return math.sqrt(mean_squared_log_error(x, y))
# Define testcases
# TODO: update remaining metrics and uncomment the corresponding test cases
TESTS = [
TestCase('accuracy',
Accuracy,
accuracy_score,
multiclass_and_binary_example),
TestCase('confusion matrix without normalize',
ConfusionMatrix,
confusion_matrix,
multiclass_and_binary_example),
TestCase('confusion matrix with normalize',
partial(ConfusionMatrix, normalize=True),
partial(confusion_matrix, normalize='true'),
multiclass_and_binary_example),
# TestCase('precision recall curve',
# PrecisionRecallCurve,
# precision_recall_curve,
# binary_example),
TestCase('precision',
Precision,
partial(precision_score, average='micro'),
multiclass_and_binary_example),
TestCase('recall',
Recall,
partial(recall_score, average='micro'),
multiclass_and_binary_example),
# TestCase('average_precision',
# AveragePrecision,
# average_precision_score,
# binary_example),
# TestCase('auroc',
# AUROC,
# roc_auc_score,
# binary_example),
TestCase('f beta',
partial(FBeta, beta=2),
partial(fbeta_score, average='micro', beta=2),
multiclass_and_binary_example),
TestCase('f1',
F1,
partial(f1_score, average='micro'),
multiclass_and_binary_example),
# TestCase('roc',
# ROC,
# roc_curve,
# binary_example),
# TestCase('multiclass roc',
# MulticlassROC,
# multiclass_roc,
# binary_example),
# TestCase('multiclass precision recall curve',
# MulticlassPrecisionRecallCurve,
# multiclass_precision_recall_curve,
# binary_example),
# TestCase('dice coefficient',
# DiceCoefficient,
# partial(f1_score, average='micro'),
# multiclass_and_binary_example),
# TestCase('intersection over union',
# IoU,
# partial(jaccard_score, average='macro'),
# binary_example),
TestCase('mean squared error',
MSE,
mean_squared_error,
regression_example),
TestCase('root mean squared error',
RMSE,
root_mean_squared_error,
regression_example),
TestCase('mean absolute error',
MAE,
mean_absolute_error,
regression_example),
TestCase('root mean squared log error',
RMSLE,
root_mean_squared_log_error,
regression_example),
TestCase('peak signal-to-noise ratio',
partial(PSNR, data_range=10),
partial(peak_signal_noise_ratio, data_range=10),
regression_example),
# TestCase('structural similarity index measure',
# SSIM,
# structural_similarity,
# regression_example)
]
# Utility test functions
def _idsfn(test):
""" Return id for current example being tested """
return test.name
def _setup_ddp(rank, worldsize):
""" setup ddp enviroment for testing """
import os
os.environ['MASTER_ADDR'] = 'localhost'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
def comparing_fn(lightning_val, comparing_val, rtol=1e-03, atol=1e-08):
""" function for comparing output, both multi and single output"""
# multi output
if isinstance(comparing_val, tuple):
for l_score, c_score in zip(lightning_val, comparing_val):
assert np.allclose(l_score.numpy(), c_score, rtol, atol)
else: # single output
assert np.allclose(lightning_val.numpy(), comparing_val, rtol, atol)
# ===== Tests start here =====
def _test_ddp_single_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs):
""" ddp testing function, divide test_inputs equally between all processes """
_setup_ddp(rank, worldsize)
# Setup metric for ddp
lightning_metric = lightning_metric()
for test_input in test_inputs:
# rank 0 receives samples 0,2,4,...
# rank 1 receives samples 1,3,5,...
lightning_val = lightning_metric(*[ti[rank::2] for ti in test_input])
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
comparing_fn(lightning_val, comparing_val)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_ddp(test):
"""Make sure that metrics are correctly sync and reduced in DDP mode"""
tutils.reset_seed()
tutils.set_random_master_port()
worldsize = 2
mp.spawn(_test_ddp_single_batch,
args=(worldsize,
test.lightning_metric,
test.comparing_metric,
test.test_input),
nprocs=worldsize)
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_multi_batch(test):
""" test that aggregation works for multiple batches """
lightning_metric = test.lightning_metric()
comparing_metric = test.comparing_metric
for test_input in test.test_input:
for i in range(2): # artificially split the samples into 2 batches for the lightning metric
# first batch consists of samples 0,2,4,...
# second batch consists of samples 1,3,5,...
_ = lightning_metric(*[ti[i::2] for ti in test_input])
lightning_val = lightning_metric.aggregated
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
comparing_fn(lightning_val, comparing_val)
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_multi_batch_unequal_sizes(test):
""" test that aggregation works for multiple batches with uneven sizes """
lightning_metric = test.lightning_metric()
comparing_metric = test.comparing_metric
for test_input in test.test_input:
for i in range(2): # artificially split the samples into 2 batches for the lightning metric
if i == 0: # allocate 3/4 of data to the first batch
_ = lightning_metric(*[ti[:int(3 / 4 * len(ti))] for ti in test_input])
else:
_ = lightning_metric(*[ti[int(3 / 4 * len(ti)):] for ti in test_input])
lightning_val = lightning_metric.aggregated
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
comparing_fn(lightning_val, comparing_val)
def _test_ddp_multi_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs):
""" ddp testing function, test that metric works with aggregation over multiple
devices and multiple batches """
_setup_ddp(rank, worldsize)
# Setup metric for ddp
lightning_metric = lightning_metric()
for test_input in test_inputs:
for i in range(2): # artificially divide samples between batches and processes
# rank 0, batch 0 consists of samples 0,4,8,...
# rank 0, batch 1 consists of samples 1,5,9,...
# rank 1, batch 0 consists of samples 2,6,10,...
# rank 1, batch 1 consists of samples 3,7,11,...
_ = lightning_metric(*[ti[i + worldsize * rank::4] for ti in test_input])
lightning_val = lightning_metric.aggregated
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
comparing_fn(lightning_val, comparing_val)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
def test_ddp_multi_batch(test):
""" test that aggregation works fine with in DDP mode and multiple batches """
tutils.reset_seed()
tutils.set_random_master_port()
worldsize = 2
mp.spawn(_test_ddp_multi_batch,
args=(worldsize,
test.lightning_metric,
test.comparing_metric,
test.test_input),
nprocs=worldsize)

@ -1,237 +0,0 @@
# NOTE: This file only tests if modules with arguments are running fine.
# The actual metric implementation is tested in functional/test_classification.py
# Especially reduction and reducing across processes won't be tested here!
import pytest
import torch
from pytorch_lightning.metrics.classification import (
Accuracy,
ConfusionMatrix,
PrecisionRecallCurve,
Precision,
Recall,
AveragePrecision,
AUROC,
FBeta,
F1,
ROC,
MulticlassROC,
MulticlassPrecisionRecallCurve,
DiceCoefficient,
IoU,
)
@pytest.fixture
def random():
torch.manual_seed(0)
@pytest.mark.parametrize('num_classes', [1, None])
def test_accuracy(num_classes):
acc = Accuracy(num_classes=num_classes)
assert acc.name == 'accuracy'
result = acc(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
assert isinstance(result, torch.Tensor)
@pytest.mark.parametrize(['normalize', 'num_classes'], [
pytest.param(False, None),
pytest.param(True, None),
pytest.param(False, 3)
])
def test_confusion_matrix(normalize, num_classes):
conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
assert conf_matrix.name == 'confusion_matrix'
target = (torch.arange(120) % 3).view(-1, 1)
pred = target.clone()
cm = conf_matrix(pred, target)
assert isinstance(cm, torch.Tensor)
@pytest.mark.parametrize(['normalize', 'num_classes'], [
pytest.param(True, 3)
])
def test_confusion_matrix_norm(normalize, num_classes):
""" test that user is warned if confusion matrix contains nans that are changed to zeros"""
conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
assert conf_matrix.name == 'confusion_matrix'
with pytest.warns(UserWarning, match='6 nan values found in confusion matrix have been replaced with zeros.'):
target = torch.LongTensor([0] * 5)
pred = target.clone()
cm = conf_matrix(pred, target)
assert isinstance(cm, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2.])
def test_precision_recall(pos_label):
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
pr_curve = PrecisionRecallCurve(pos_label=pos_label)
assert pr_curve.name == 'precision_recall_curve'
pr = pr_curve(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(pr, tuple)
assert len(pr) == 3
for tmp in pr:
assert isinstance(tmp, torch.Tensor)
@pytest.mark.parametrize('num_classes', [1, None])
def test_precision(num_classes):
precision = Precision(num_classes=num_classes)
assert precision.name == 'precision'
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
prec = precision(pred=pred, target=target)
assert isinstance(prec, torch.Tensor)
@pytest.mark.parametrize('num_classes', [1, None])
def test_recall(num_classes):
recall = Recall(num_classes=num_classes)
assert recall.name == 'recall'
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
rec = recall(pred=pred, target=target)
assert isinstance(rec, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2])
def test_average_precision(pos_label):
avg_prec = AveragePrecision(pos_label=pos_label)
assert avg_prec.name == 'AP'
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
ap = avg_prec(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(ap, torch.Tensor)
@pytest.mark.parametrize('pos_label', [0, 1])
def test_auroc(pos_label):
auroc = AUROC(pos_label=pos_label)
assert auroc.name == 'auroc'
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 1, 0, 1])
area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(area, torch.Tensor)
@pytest.mark.parametrize(['beta', 'num_classes'], [
pytest.param(0., 1),
pytest.param(0.5, 1),
pytest.param(1., 1),
pytest.param(2., 1),
pytest.param(0., None),
pytest.param(0.5, None),
pytest.param(1., None),
pytest.param(2., None)
])
def test_fbeta(beta, num_classes):
fbeta = FBeta(beta=beta, num_classes=num_classes)
assert fbeta.name == 'fbeta'
score = fbeta(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
assert isinstance(score, torch.Tensor)
@pytest.mark.parametrize('num_classes', [1, None])
def test_f1(num_classes):
f1 = F1(num_classes=num_classes)
assert f1.name == 'f1'
score = f1(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
assert isinstance(score, torch.Tensor)
@pytest.mark.parametrize('pos_label', [1, 2])
def test_roc(pos_label):
roc = ROC(pos_label=pos_label)
assert roc.name == 'roc'
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3])
res = roc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(res, tuple)
assert len(res) == 3
for tmp in res:
assert isinstance(tmp, torch.Tensor)
@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_roc(num_classes):
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
multi_roc = MulticlassROC(num_classes=num_classes)
assert multi_roc.name == 'multiclass_roc'
res = multi_roc(pred, target)
assert isinstance(res, tuple)
if num_classes is not None:
assert len(res) == num_classes
for tmp in res:
assert isinstance(tmp, tuple)
assert len(tmp) == 3
for _tmp in tmp:
assert isinstance(_tmp, torch.Tensor)
@pytest.mark.parametrize('num_classes', [4, None])
def test_multiclass_pr(num_classes):
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
multi_pr = MulticlassPrecisionRecallCurve(num_classes=num_classes)
assert multi_pr.name == 'multiclass_precision_recall_curve'
pr = multi_pr(pred, target)
assert isinstance(pr, tuple)
if num_classes is not None:
assert len(pr) == num_classes
for tmp in pr:
assert isinstance(tmp, tuple)
assert len(tmp) == 3
for _tmp in tmp:
assert isinstance(_tmp, torch.Tensor)
@pytest.mark.parametrize('include_background', [True, False])
def test_dice_coefficient(include_background):
dice_coeff = DiceCoefficient(include_background=include_background)
assert dice_coeff.name == 'dice'
dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)),
torch.randint(0, 1, (10, 25, 25)))
assert isinstance(dice, torch.Tensor)
@pytest.mark.parametrize('ignore_index', [0, 1, None])
def test_iou(ignore_index):
iou = IoU(ignore_index=ignore_index)
assert iou.name == 'iou'
score = iou(torch.randint(0, 1, (10, 25, 25)),
torch.randint(0, 1, (10, 25, 25)))
assert isinstance(score, torch.Tensor)

@ -1,265 +0,0 @@
import sys
import numpy as np
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import tests.base.develop_utils as tutils
from pytorch_lightning.metrics.converters import (
_apply_to_inputs,
_apply_to_outputs,
convert_to_tensor,
convert_to_numpy,
_numpy_metric_conversion,
_tensor_metric_conversion,
sync_ddp_if_available,
gather_all_tensors_if_available,
tensor_metric,
numpy_metric
)
def test_apply_to_inputs():
def apply_fn(inputs, factor):
if isinstance(inputs, (float, int)):
return inputs * factor
elif isinstance(inputs, dict):
return {k: apply_fn(v, factor) for k, v in inputs.items()}
elif isinstance(inputs, (tuple, list)):
return [apply_fn(x, factor) for x in inputs]
@_apply_to_inputs(apply_fn, factor=2.)
def test_fn(*args, **kwargs):
return args, kwargs
for args in [[], [1., 2.]]:
for kwargs in [{}, {'a': 1., 'b': 2.}]:
result_args, result_kwargs = test_fn(*args, **kwargs)
assert isinstance(result_args, (list, tuple))
assert isinstance(result_kwargs, dict)
assert len(result_args) == len(args)
assert len(result_kwargs) == len(kwargs)
assert all([k in result_kwargs for k in kwargs.keys()])
for arg, result_arg in zip(args, result_args):
assert arg * 2. == result_arg
for key in kwargs.keys():
arg = kwargs[key]
result_arg = result_kwargs[key]
assert arg * 2. == result_arg
def test_apply_to_outputs():
def apply_fn(inputs, additional_str):
return str(inputs) + additional_str
@_apply_to_outputs(apply_fn, additional_str='_str')
def test_fn(*args, **kwargs):
return 'dummy'
assert test_fn() == 'dummy_str'
def test_convert_to_tensor():
for test_item in [1., np.array([1.])]:
result_tensor = convert_to_tensor(test_item)
assert isinstance(result_tensor, torch.Tensor)
assert result_tensor.item() == 1.
def test_convert_to_numpy():
for test_item in [1., torch.tensor([1.])]:
result = convert_to_numpy(test_item)
assert isinstance(result, np.ndarray)
assert result.item() == 1.
def test_numpy_metric_conversion():
@_numpy_metric_conversion
def numpy_test_metric(*args, **kwargs):
for arg in args:
assert isinstance(arg, np.ndarray)
for v in kwargs.values():
assert isinstance(v, np.ndarray)
return 5.
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
assert isinstance(result, torch.Tensor)
assert result.item() == 5.
def test_tensor_metric_conversion():
@_tensor_metric_conversion
def tensor_test_metric(*args, **kwargs):
for arg in args:
assert isinstance(arg, torch.Tensor)
for v in kwargs.values():
assert isinstance(v, torch.Tensor)
return 5.
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
assert isinstance(result, torch.Tensor)
assert result.item() == 5.
def _setup_ddp(rank, worldsize):
import os
os.environ['MASTER_ADDR'] = 'localhost'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
_setup_ddp(rank, worldsize)
if add_offset:
tensor = torch.tensor([float(rank)])
else:
tensor = torch.tensor([1.], )
if reduction_mean:
reduced_tensor = sync_ddp_if_available(tensor, reduce_op='avg')
manual_reduction = sum([i for i in range(dist.get_world_size())]) / dist.get_world_size()
assert reduced_tensor.item() == manual_reduction
else:
reduced_tensor = sync_ddp_if_available(tensor)
assert reduced_tensor.item() == dist.get_world_size(), \
'Sync-Reduce does not work properly with DDP and Tensors'
def _ddp_test_gather_all_tensors(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([rank])
gather_tensors = gather_all_tensors_if_available(tensor)
manual_tensors = [torch.tensor([i]) for i in range(worldsize)]
for t1, t2 in zip(gather_tensors, manual_tensors):
assert t1.equal(t2)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_reduce_ddp():
"""Make sure sync-reduce works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()
worldsize = 2
mp.spawn(_ddp_test_fn, args=(worldsize, False), nprocs=worldsize)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_sync_reduce_ddp_mean():
"""Make sure sync-reduce works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()
worldsize = 2
mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)
def test_sync_reduce_simple():
"""Make sure sync-reduce works without DDP"""
tensor = torch.tensor([1.], device='cpu')
reduced_tensor = sync_ddp_if_available(tensor)
assert torch.allclose(tensor, reduced_tensor), \
'Sync-Reduce does not work properly without DDP and Tensors'
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_gather_all_tensors_ddp():
"""Make sure gather_all_tensors works with DDP"""
tutils.reset_seed()
tutils.set_random_master_port()
worldsize = 2
mp.spawn(_ddp_test_gather_all_tensors, args=(worldsize, ), nprocs=worldsize)
def _test_tensor_metric(is_ddp: bool):
@tensor_metric()
def tensor_test_metric(*args, **kwargs):
for arg in args:
assert isinstance(arg, torch.Tensor)
for v in kwargs.values():
assert isinstance(v, torch.Tensor)
return 5.
if is_ddp:
factor = dist.get_world_size()
else:
factor = 1.
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
assert isinstance(result, torch.Tensor)
assert result.item() == 5. * factor
def _ddp_test_tensor_metric(rank, worldsize):
_setup_ddp(rank, worldsize)
_test_tensor_metric(True)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_tensor_metric_ddp():
tutils.reset_seed()
tutils.set_random_master_port()
world_size = 2
mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)
# dist.destroy_process_group()
def test_tensor_metric_simple():
_test_tensor_metric(False)
def _test_numpy_metric(is_ddp: bool):
@numpy_metric()
def numpy_test_metric(*args, **kwargs):
for arg in args:
assert isinstance(arg, np.ndarray)
for v in kwargs.values():
assert isinstance(v, np.ndarray)
return 5.
if is_ddp:
factor = dist.get_world_size()
else:
factor = 1.
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
assert isinstance(result, torch.Tensor)
assert result.item() == 5. * factor
def _ddp_test_numpy_metric(rank, worldsize):
_setup_ddp(rank, worldsize)
_test_numpy_metric(True)
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_numpy_metric_ddp():
tutils.reset_seed()
tutils.set_random_master_port()
world_size = 2
mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
# dist.destroy_process_group()
def test_numpy_metric_simple():
_test_numpy_metric(False)

tests/metrics/test_ddp.py (new file, 45 lines)

@ -0,0 +1,45 @@
import pytest
import torch
import os
import sys
from tests.metrics.test_metric import Dummy
from tests.metrics.utils import setup_ddp
torch.manual_seed(42)
def _test_ddp_sum(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = Dummy()
dummy._reductions = {"foo": torch.sum}
dummy.foo = torch.tensor(1)
dummy._sync_dist()
assert dummy.foo == worldsize
def _test_ddp_cat(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = Dummy()
dummy._reductions = {"foo": torch.cat}
dummy.foo = [torch.tensor([1])]
dummy._sync_dist()
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
def _test_ddp_sum_cat(rank, worldsize):
setup_ddp(rank, worldsize)
dummy = Dummy()
dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
dummy.foo = [torch.tensor([1])]
dummy.bar = torch.tensor(1)
dummy._sync_dist()
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
assert dummy.bar == worldsize
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
def test_ddp(process):
torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
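
setup_ddp is imported from tests/metrics/utils.py, which is not part of this excerpt. Assuming it mirrors the _setup_ddp helpers removed above, it looks roughly like this (the fixed port is a placeholder assumption):

import os
import torch.distributed as dist

def setup_ddp_sketch(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8088"  # placeholder; the removed helpers relied on set_random_master_port() instead
    # initialize the process group with the gloo backend, as in the removed helpers
    dist.init_process_group("gloo", rank=rank, world_size=world_size)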

@ -0,0 +1,108 @@
import pytest
import torch
from pytorch_lightning.metrics.metric import Metric
import os
import numpy as np
torch.manual_seed(42)
class Dummy(Metric):
name = "Dummy"
def __init__(self):
super().__init__()
self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
def update(self):
pass
def compute(self):
pass
def test_inherit():
a = Dummy()
def test_add_state():
a = Dummy()
a.add_state("a", torch.tensor(0), "sum")
assert a._reductions["a"](torch.tensor([1, 1])) == 2
a.add_state("b", torch.tensor(0), "mean")
assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)
a.add_state("c", torch.tensor(0), "cat")
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2,)
with pytest.raises(ValueError):
a.add_state("d1", torch.tensor(0), 'xyz')
with pytest.raises(ValueError):
a.add_state("d2", torch.tensor(0), 42)
with pytest.raises(ValueError):
a.add_state("d3", [torch.tensor(0)], 'sum')
with pytest.raises(ValueError):
a.add_state("d4", 42, 'sum')
def custom_fx(x):
return -1
a.add_state("e", torch.tensor(0), custom_fx)
assert a._reductions["e"](torch.tensor([1, 1])) == -1
def test_reset():
class A(Dummy):
pass
a = A()
assert a.x == 0
a.x = torch.tensor(5)
a.reset()
assert a.x == 0
def test_update():
class A(Dummy):
def update(self, x):
self.x += x
a = A()
assert a.x == 0
assert a._computed is None
a.update(1)
assert a._computed is None
assert a.x == 1
a.update(2)
assert a.x == 3
assert a._computed is None
def test_compute():
class A(Dummy):
def update(self, x):
self.x += x
def compute(self):
return self.x
a = A()
assert 0 == a.compute()
assert 0 == a.x
a.update(1)
assert a._computed is None
assert a.compute() == 1
assert a._computed == 1
a.update(2)
assert a._computed is None
assert a.compute() == 2
assert a._computed == 2
# called without update, should return cached value
a._computed = 5
assert a.compute() == 5
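# A minimal sketch of a metric built on the API exercised above; "SumMetric" is a
# hypothetical example, assuming only the add_state/update/compute contract shown here:
# class SumMetric(Metric):
#     def __init__(self):
#         super().__init__()
#         self.add_state("total", torch.tensor(0.), dist_reduce_fx="sum")
#     def update(self, x):
#         self.total += x
#     def compute(self):
#         return self.total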


@@ -1,323 +0,0 @@
import os
from typing import Any
import numpy as np
import pytest
import torch
import tests.base.develop_utils as tutils
from tests.base import EvalModelTemplate
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
from pytorch_lightning import Trainer
class DummyTensorMetric(TensorMetric):
def __init__(self):
super().__init__("dummy")
def forward(self, input1, input2):
assert isinstance(input1, torch.Tensor)
assert isinstance(input2, torch.Tensor)
return torch.tensor([1.0])
class DummyNumpyMetric(NumpyMetric):
def __init__(self):
super().__init__("dummy")
def forward(self, input1, input2):
assert isinstance(input1, np.ndarray)
assert isinstance(input2, np.ndarray)
return 1.0
class DummyTensorCollectionMetric(TensorMetric):
def __init__(self):
super().__init__("dummy")
def forward(self, input1, input2):
assert isinstance(input1, torch.Tensor)
assert isinstance(input2, torch.Tensor)
return 1.0, 2.0, 3.0, 4.0
@pytest.mark.parametrize("metric", [DummyTensorCollectionMetric()])
def test_collection_metric(metric: Metric):
""" Test that metric.device, metric.dtype works for metric collection """
input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])
def change_and_check_device_dtype(device, dtype):
metric.to(device=device, dtype=dtype)
metric_val = metric(input1, input2)
assert not isinstance(metric_val, torch.Tensor)
if device is not None:
assert metric.device in [device, torch.device(device)]
if dtype is not None:
assert metric.dtype == dtype
devices = [None, "cpu"]
if torch.cuda.is_available():
devices += ["cuda:0"]
for device in devices:
for dtype in [None, torch.float32, torch.float64]:
change_and_check_device_dtype(device=device, dtype=dtype)
if torch.cuda.is_available():
metric.cuda(0)
assert metric.device == torch.device("cuda", index=0)
metric.cpu()
assert metric.device == torch.device("cpu")
metric.type(torch.int8)
assert metric.dtype == torch.int8
metric.float()
assert metric.dtype == torch.float32
metric.double()
assert metric.dtype == torch.float64
assert all(out.dtype == torch.float64 for out in metric(input1, input2))
if torch.cuda.is_available():
metric.cuda()
metric.half()
assert metric.dtype == torch.float16
@pytest.mark.parametrize(
"metric",
[
DummyTensorMetric(),
DummyNumpyMetric(),
],
)
def test_metric(metric: Metric):
""" Test that metric.device, metric.dtype works for single metric"""
input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])
def change_and_check_device_dtype(device, dtype):
metric.to(device=device, dtype=dtype)
metric_val = metric(input1, input2)
assert isinstance(metric_val, torch.Tensor)
if device is not None:
assert metric.device in [device, torch.device(device)]
assert metric_val.device in [device, torch.device(device)]
if dtype is not None:
assert metric.dtype == dtype
assert metric_val.dtype == dtype
devices = [None, "cpu"]
if torch.cuda.is_available():
devices += ["cuda:0"]
for device in devices:
for dtype in [None, torch.float32, torch.float64]:
change_and_check_device_dtype(device=device, dtype=dtype)
if torch.cuda.is_available():
metric.cuda(0)
assert metric.device == torch.device("cuda", index=0)
assert metric(input1, input2).device == torch.device("cuda", index=0)
metric.cpu()
assert metric.device == torch.device("cpu")
assert metric(input1, input2).device == torch.device("cpu")
metric.float()
assert metric.dtype == torch.float32
assert metric(input1, input2).dtype == torch.float32
metric.double()
assert metric.dtype == torch.float64
assert metric(input1, input2).dtype == torch.float64
if torch.cuda.is_available():
metric.cuda()
metric.half()
assert metric.dtype == torch.float16
assert metric(input1, input2).dtype == torch.float16
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.parametrize("metric", [DummyTensorMetric, DummyNumpyMetric])
def test_model_pickable(tmpdir, metric: Metric):
"""Make sure that metrics are pickable by including into a model and running in multi-gpu mode"""
tutils.set_random_master_port()
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=10,
gpus=[0, 1],
distributed_backend="ddp_spawn",
)
model = EvalModelTemplate()
model.metric = metric()
model.training_step = model.training_step__using_metrics
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# correct result and ok accuracy
assert result == 1, "ddp model failed to complete"
@pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()])
def test_saving_pickable(tmpdir, metric: Metric):
""" Make sure that metrics are pickable by saving and loading them using torch """
x, y = torch.randn(10,), torch.randn(10,)
results_before_save = metric(x, y)
# save metric
save_path = os.path.join(tmpdir, "save_test.ckpt")
torch.save(metric, save_path)
# load metric
new_metric = torch.load(save_path)
results_after_load = new_metric(x, y)
# Check metric value is the same
assert results_before_save == results_after_load
def test_correct_call_order():
""" Check that hooks are called in the expected order """
class DummyMetric(Metric):
def __init__(self):
super().__init__("dummy")
self.call_history = ["init"]
@staticmethod
def input_convert(self, data: Any):
self.call_history.append("input_convert")
return super(DummyMetric, self).input_convert(self, data)
def forward(self, tensor1, tensor2):
self.call_history.append("forward")
return tensor1 - tensor2
@staticmethod
def output_convert(self, data: Any, output: Any):
self.call_history.append("output_convert")
return super(DummyMetric, self).output_convert(self, data, output)
def ddp_sync(self, tensor: Any):
self.call_history.append("ddp_sync")
return super().ddp_sync(tensor)
@staticmethod
def ddp_reduce(self, data: Any, output: Any):
self.call_history.append("ddp_reduce")
return super(DummyMetric, self).ddp_reduce(self, data, output)
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
self.call_history.append("aggregate")
return super().aggregate(*tensors)
def reset(self):
self.call_history.append("reset")
return super().reset()
@property
def aggregated(self) -> torch.Tensor:
self.call_history.append("aggregated")
return super().aggregated
@staticmethod
def compute(self, data: Any, output: Any):
self.call_history.append("compute")
return super(DummyMetric, self).compute(self, data, output)
metric = DummyMetric()
assert metric.call_history == ["init"]
result = metric(torch.tensor([2.0]), torch.tensor([1.0]))
assert torch.allclose(result, torch.tensor(1.0))
assert metric.call_history == [
"init",
"input_convert",
"forward",
"output_convert",
"ddp_reduce",
"ddp_sync",
"aggregate",
"compute"
]
aggr = metric.aggregated
assert metric.call_history == [
"init",
"input_convert",
"forward",
"output_convert",
"ddp_reduce",
"ddp_sync",
"aggregate",
"compute",
"aggregated",
"aggregate",
"reset",
"compute"
]
assert torch.allclose(aggr, result)
_ = metric(torch.tensor(2.0), torch.tensor(1.0))
assert metric.call_history == [
"init",
"input_convert",
"forward",
"output_convert",
"ddp_reduce",
"ddp_sync",
"aggregate",
"compute",
"aggregated",
"aggregate",
"reset",
"compute",
"input_convert",
"forward",
"output_convert",
"ddp_reduce",
"ddp_sync",
"aggregate",
"compute"
]
metric = DummyMetric()
_ = metric(torch.tensor([2.0]), torch.tensor([1.0]))
_ = metric(torch.tensor([3.0]), torch.tensor([0.0]))
aggregated = metric.aggregated
assert torch.allclose(aggregated, torch.tensor(4.0))
assert metric.call_history == [
"init",
"input_convert",
"forward",
"output_convert",
"ddp_reduce",
"ddp_sync",
"aggregate",
"compute",
"input_convert",
"forward",
"output_convert",
"ddp_reduce",
"ddp_sync",
"aggregate",
"compute",
"aggregated",
"aggregate",
"reset",
"compute",
]


@@ -1,29 +0,0 @@
import pytest
import torch
from pytorch_lightning.metrics.nlp import BLEUScore
# example taken from
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
HYP2 = "he read the book because he was interested in world history".split()
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
REF1B = "It is a guiding principle which makes the military forces always being under the command of the Party".split()
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
REF2A = "he was interested in world history because he read the book".split()
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
HYPOTHESES = [HYP1, HYP2]
@pytest.mark.parametrize(
["n_gram", "smooth"],
[pytest.param(1, True), pytest.param(2, False), pytest.param(3, True), pytest.param(4, False),],
)
def test_bleu(smooth, n_gram):
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)
assert bleu.name == "bleu"
pl_output = bleu(HYPOTHESES, LIST_OF_REFERENCES)
assert isinstance(pl_output, torch.Tensor)


@@ -1,69 +0,0 @@
# NOTE: This file only tests if modules with arguments are running fine.
# The actual metric implementation is tested in functional/test_regression.py
# Especially reduction and reducing across processes won't be tested here!
import torch
from pytorch_lightning.metrics.regression import (
MAE, MSE, RMSE, RMSLE, PSNR, SSIM
)
def test_mse():
mse = MSE()
assert mse.name == 'mse'
pred = torch.tensor([0., 1, 2, 3])
target = torch.tensor([0., 1, 2, 2])
score = mse(pred, target)
assert isinstance(score, torch.Tensor)
def test_rmse():
rmse = RMSE()
assert rmse.name == 'rmse'
pred = torch.tensor([0., 1, 2, 3])
target = torch.tensor([0., 1, 2, 2])
score = rmse(pred, target)
assert isinstance(score, torch.Tensor)
def test_mae():
mae = MAE()
assert mae.name == 'mae'
pred = torch.tensor([0., 1, 2, 3])
target = torch.tensor([0., 1, 2, 2])
score = mae(pred, target)
assert isinstance(score, torch.Tensor)
def test_rmsle():
rmsle = RMSLE()
assert rmsle.name == 'rmsle'
pred = torch.tensor([0., 1, 2, 3])
target = torch.tensor([0., 1, 2, 2])
score = rmsle(pred, target)
assert isinstance(score, torch.Tensor)
def test_psnr():
psnr = PSNR()
assert psnr.name == 'psnr'
pred = torch.tensor([0., 1, 2, 3])
target = torch.tensor([0., 1, 2, 2])
score = psnr(pred, target)
assert isinstance(score, torch.Tensor)
def test_ssim():
ssim = SSIM()
assert ssim.name == 'ssim'
pred = torch.rand([16, 1, 16, 16])
target = pred * 0.75
score = ssim(pred, target)
assert isinstance(score, torch.Tensor)


@@ -1,178 +0,0 @@
import numbers
from functools import partial
import numpy as np
import pytest
import torch
from sklearn.metrics import (
accuracy_score as sk_accuracy,
precision_score as sk_precision,
recall_score as sk_recall,
f1_score as sk_f1_score,
fbeta_score as sk_fbeta_score,
confusion_matrix as sk_confusion_matrix,
average_precision_score as sk_average_precision,
auc as sk_auc,
precision_recall_curve as sk_precision_recall_curve,
roc_curve as sk_roc_curve,
roc_auc_score as sk_roc_auc_score,
balanced_accuracy_score as sk_balanced_accuracy_score,
dcg_score as sk_dcg_score,
mean_absolute_error as sk_mean_absolute_error,
mean_squared_error as sk_mean_squared_error,
mean_squared_log_error as sk_mean_squared_log_error,
median_absolute_error as sk_median_absolute_error,
r2_score as sk_r2_score,
mean_poisson_deviance as sk_mean_poisson_deviance,
mean_gamma_deviance as sk_mean_gamma_deviance,
mean_tweedie_deviance as sk_mean_tweedie_deviance,
explained_variance_score as sk_explained_variance_score,
cohen_kappa_score as sk_cohen_kappa_score,
hamming_loss as sk_hamming_loss,
hinge_loss as sk_hinge_loss,
jaccard_score as sk_jaccard_score
)
from pytorch_lightning.metrics.converters import convert_to_numpy
from pytorch_lightning.metrics.sklearns import (
Accuracy,
AUC,
AveragePrecision,
BalancedAccuracy,
ConfusionMatrix,
CohenKappaScore,
DCG,
F1,
FBeta,
Hamming,
Hinge,
Jaccard,
Precision,
Recall,
PrecisionRecallCurve,
ROC,
AUROC,
MeanAbsoluteError,
MeanSquaredError,
MeanSquaredLogError,
MedianAbsoluteError,
R2Score,
MeanPoissonDeviance,
MeanGammaDeviance,
MeanTweedieDeviance,
ExplainedVariance,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
def _xy_only(func):
def new_func(*args, **kwargs):
return np.array(func(*args, **kwargs)[:2])
return new_func
@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
pytest.param(Accuracy(), sk_accuracy,
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='Accuracy'),
pytest.param(AUC(), sk_auc,
{'x': torch.arange(10, dtype=torch.float) / 10,
'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])},
id='AUC'),
pytest.param(AveragePrecision(), sk_average_precision,
{'y_score': torch.randint(2, size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='AveragePrecision'),
pytest.param(ConfusionMatrix(), sk_confusion_matrix,
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='ConfusionMatrix'),
pytest.param(F1(average='macro'), partial(sk_f1_score, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='F1'),
pytest.param(FBeta(beta=0.5, average='macro'), partial(sk_fbeta_score, beta=0.5, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='FBeta'),
pytest.param(Precision(average='macro'), partial(sk_precision, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='Precision'),
pytest.param(Recall(average='macro'), partial(sk_recall, average='macro'),
{'y_pred': torch.randint(10, size=(128,)),
'y_true': torch.randint(10, size=(128,))},
id='Recall'),
pytest.param(PrecisionRecallCurve(), _xy_only(sk_precision_recall_curve),
{'probas_pred': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='PrecisionRecallCurve'),
pytest.param(ROC(), _xy_only(sk_roc_curve),
{'y_score': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='ROC'),
pytest.param(AUROC(), sk_roc_auc_score,
{'y_score': torch.rand(size=(128,)),
'y_true': torch.randint(2, size=(128,))},
id='AUROC'),
pytest.param(BalancedAccuracy(), sk_balanced_accuracy_score,
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='BalancedAccuracy'),
pytest.param(DCG(), sk_dcg_score,
{'y_score': torch.rand(size=(128, 3)), 'y_true': torch.randint(3, size=(128, 3))},
id='DCG'),
pytest.param(ExplainedVariance(), sk_explained_variance_score,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='ExplainedVariance'),
pytest.param(MeanAbsoluteError(), sk_mean_absolute_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanAbsoluteError'),
pytest.param(MeanSquaredError(), sk_mean_squared_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanSquaredError'),
pytest.param(MeanSquaredLogError(), sk_mean_squared_log_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanSquaredLogError'),
pytest.param(MedianAbsoluteError(), sk_median_absolute_error,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MedianAbsoluteError'),
pytest.param(R2Score(), sk_r2_score,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='R2Score'),
pytest.param(MeanPoissonDeviance(), sk_mean_poisson_deviance,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanPoissonDeviance'),
pytest.param(MeanGammaDeviance(), sk_mean_gamma_deviance,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanGammaDeviance'),
pytest.param(MeanTweedieDeviance(), sk_mean_tweedie_deviance,
{'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
id='MeanTweedieDeviance'),
pytest.param(CohenKappaScore(), sk_cohen_kappa_score,
{'y1': torch.randint(3, size=(128,)), 'y2': torch.randint(3, size=(128,))},
id='CohenKappaScore'),
pytest.param(Hamming(), sk_hamming_loss,
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='Hamming'),
pytest.param(Hinge(), sk_hinge_loss,
{'pred_decision': torch.randn(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
id='Hinge'),
pytest.param(Jaccard(average='macro'), partial(sk_jaccard_score, average='macro'),
{'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
id='Jaccard')
])
def test_sklearn_metric(metric_class, sklearn_func, inputs):
numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
sklearn_result = sklearn_func(**numpy_inputs)
lightning_result = metric_class(**inputs)
sklearn_result = apply_to_collection(
sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
lightning_result = np.array(apply_to_collection(
lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy))
assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
assert isinstance(lightning_result, type(sklearn_result))

tests/metrics/utils.py

@@ -0,0 +1,61 @@
import torch
import numpy as np
import os
import sys
import pytest
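# Shared dimensions for the metric test helpers below.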
NUM_PROCESSES = 2
NUM_BATCHES = 10
BATCH_SIZE = 16
def setup_ddp(rank, world_size):
os.environ["MASTER_ADDR"] = 'localhost'
os.environ['MASTER_PORT'] = '8088'
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
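# Runs the metric batch by batch on a single rank and compares both the per-step
# results and the final compute() against the reference sk_metric implementation.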
def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_step, worldsize=1, metric_args={}):
metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)
# Only use DDP if the world size is larger than 1
if worldsize > 1:
setup_ddp(rank, worldsize)
for i in range(rank, NUM_BATCHES, worldsize):
batch_result = metric(preds[i], target[i])
if metric.ddp_sync_on_step:
if rank == 0:
ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)])
ddp_target = torch.stack([target[i + r] for r in range(worldsize)])
sk_batch_result = sk_metric(ddp_preds, ddp_target)
assert np.allclose(batch_result.numpy(), sk_batch_result)
else:
sk_batch_result = sk_metric(preds[i], target[i])
assert np.allclose(batch_result.numpy(), sk_batch_result)
# check on all batches on all ranks
result = metric.compute()
assert isinstance(result, torch.Tensor)
total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)])
total_target = torch.stack([target[i] for i in range(NUM_BATCHES)])
sk_result = sk_metric(total_preds, total_target)
assert np.allclose(result.numpy(), sk_result)
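# Entry point for the metric test modules: spawns NUM_PROCESSES workers when ddp=True,
# otherwise runs _compute_batch in-process with world size 1.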
def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}):
if ddp:
if sys.platform == "win32":
pytest.skip("DDP not supported on windows")
torch.multiprocessing.spawn(
_compute_batch, args=(preds, target, metric_class, sk_metric, ddp_sync_on_step, NUM_PROCESSES, metric_args),
nprocs=NUM_PROCESSES
)
else:
# first arg: rank 0, second-to-last arg: world size 1 (single process)
_compute_batch(0, preds, target, metric_class, sk_metric, ddp_sync_on_step, 1, metric_args)