revamp entire metrics (#3868)
* removed metric
* added new metrics
* pep8
* pep8
* docs
* docs
* win ddp tests skip
* win ddp tests skip
* win ddp tests skip
* win ddp tests skip
* reset in compute, cache compute
* reduce_ops handling
* sync -> sync_dist, type annotations
* wip docs
* mean squared error
* docstring
* added mean ___ error metrics
* added mean ___ error metrics
* seperated files
* accuracy doctest
* gpu fix
* remove unnecessary mixin
* metric and accuracy docstring
* metric docs
* pep8, changelog
* refactor dist utils, pep8
* refactor dist utils, pep8

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
This commit is contained in:
parent
4722cc0bf0
commit
f76bc5254e
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added new Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))

- Enable PyTorch 1.7 compatibility ([#3541](https://github.com/PyTorchLightning/pytorch-lightning/pull/3541))

- Added hooks to metric module interface ([#2528](https://github.com/PyTorchLightning/pytorch-lightning/pull/2528))

@@ -63,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed

- Remove old Metrics API. ([#3868](https://github.com/PyTorchLightning/pytorch-lightning/pull/3868))

### Fixed

@@ -3,124 +3,124 @@

import torch
from torch.nn import Module
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.metrics import TensorMetric, NumpyMetric
from pytorch_lightning.metrics import Metric

.. _metrics:

Metrics
=======
This is a general package for PyTorch metrics that can also be used with plain (non-Lightning) PyTorch code.
Metrics are used to monitor model performance.

In this package, we provide the following functionality:

1. A ``Metric`` class you can use to implement metrics with built-in distributed (DDP) support that are device agnostic.
2. A collection of ready-to-use popular metrics. There are two types of metrics: class metrics and functional metrics.
3. An interface to call `sklearn's metrics <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_

``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
common metric implementations.

The metrics API provides ``update()``, ``compute()``, and ``reset()`` methods to the user. The metric base class inherits
from ``nn.Module``, which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
provided input.
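A minimal sketch of that call flow (using the ``Accuracy`` metric documented below; the tensor values are illustrative):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    accuracy = Accuracy()

    preds = torch.tensor([0, 2, 1, 3])
    target = torch.tensor([0, 1, 2, 3])

    # calling the metric invokes update() and returns the value for this input
    batch_acc = accuracy(preds, target)

    # compute() returns the value aggregated over all update()/forward() calls so far
    total_acc = accuracy.compute()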
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
logic present in ``.compute()`` is applied to state information from all processes.

The example below shows how to use a metric in your ``LightningModule``:

Example::

    from pytorch_lightning.metrics.functional import accuracy

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])

    # calculates accuracy across all GPUs and all Nodes used in training
    accuracy(pred, target)

.. note::
    The new Metrics API is currently limited to a few metrics. Please feel free to create an issue/PR
    if you have a proposed metric or have found a bug.

.. warning::
    The metrics package is still in development! If we're missing a metric or you find a mistake, please send a PR!

----------------

Implement a metric
------------------
You can implement metrics as either a PyTorch metric or a Numpy metric (it is recommended to use PyTorch metrics when possible,
since Numpy metrics slow down training).

Use :class:`TensorMetric` to implement native PyTorch metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

Use :class:`NumpyMetric` to implement numpy metrics. This class
handles automated DDP syncing and converts all inputs and outputs to tensors.

.. warning::
    Numpy metrics might slow down your training substantially,
    since every metric computation requires a GPU sync to convert tensors to numpy.

----------------

TensorMetric
^^^^^^^^^^^^
Here's an example showing how to implement a ``TensorMetric``:

.. testcode::

    class RMSE(TensorMetric):
        def forward(self, x, y):
            return torch.sqrt(torch.mean(torch.pow(x - y, 2.0)))

.. autoclass:: pytorch_lightning.metrics.metric.TensorMetric
    :noindex:

----------------

NumpyMetric
^^^^^^^^^^^
Here's an example showing how to implement a ``NumpyMetric``:

.. testcode::

    import numpy as np

    class RMSE(NumpyMetric):
        def forward(self, x, y):
            return np.sqrt(np.mean(np.power(x - y, 2.0)))

.. autoclass:: pytorch_lightning.metrics.metric.NumpyMetric
    :noindex:

----------------

Class Metrics
-------------
Class metrics can be instantiated as part of a module definition (even with just
plain PyTorch).

.. testcode::

    from pytorch_lightning.metrics import Accuracy

    # Plain PyTorch
    class MyModule(Module):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def forward(self, x, y):
            y_hat = ...
            acc = self.metric(y_hat, y)

    # PyTorch Lightning
    class MyModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.metric = Accuracy()

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = ...
            acc = self.metric(y_hat, y)

These metrics even work when using distributed training.

For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end of each epoch,
as shown in the example below. For the v1.0 release, we will integrate metrics
with logging and ``.compute()`` will be called automatically by PyTorch Lightning.

.. code-block:: python

    def __init__(self):
        ...
        self.accuracy = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        logits = self(x)
        ...
        # log step metric
        self.log('train_acc_step', self.accuracy(logits, y))
        ...

    def training_epoch_end(self, outs):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy.compute())

.. code-block:: python

    model = MyModule()
    trainer = Trainer(gpus=8, num_nodes=2)

    # any metric automatically reduces across GPUs (even the ones you implement using Lightning)
    trainer.fit(model)

The metrics API is independent of PyTorch Lightning. Metrics can be used directly in plain PyTorch, as shown in the example:

.. code-block:: python

    from pytorch_lightning import metrics

    train_accuracy = metrics.Accuracy()
    valid_accuracy = metrics.Accuracy(compute_on_step=False)

    for epoch in range(epochs):
        for x, y in train_data:
            y_hat = model(x)

            # training step accuracy
            batch_acc = train_accuracy(y_hat, y)

        for x, y in valid_data:
            y_hat = model(x)
            valid_accuracy(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

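When the same metric object is reused across epochs, the accumulated state can be cleared between epochs. A small sketch, assuming you want a separate value per epoch (depending on the version, ``compute()`` may already reset the state internally):

.. code-block:: python

    for epoch in range(epochs):
        for x, y in train_data:
            train_accuracy(model(x), y)

        # value over this epoch only, then clear the state for the next epoch
        epoch_train_accuracy = train_accuracy.compute()
        train_accuracy.reset()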
Implementing a Metric
---------------------

To implement your own custom metric, subclass the base ``Metric`` class and implement the following methods:

- ``__init__()``: Each state variable should be registered using ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.

All you need to do is call ``add_state()`` correctly to implement a custom metric with DDP.
``reset()`` is called on metric state variables added using ``add_state()``.

To see how metric states are synchronized across distributed processes, refer to the ``add_state()`` docs
of the base ``Metric`` class.

Example implementation:

.. code-block:: python

    from pytorch_lightning.metrics import Metric

    class MyAccuracy(Metric):
        def __init__(self, ddp_sync_on_step=False):
            super().__init__(ddp_sync_on_step=ddp_sync_on_step)

            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor):
            # reshape/format the inputs here if needed (e.g. argmax over class probabilities)
            assert preds.shape == target.shape

            self.correct += torch.sum(preds == target)
            self.total += target.numel()

        def compute(self):
            return self.correct.float() / self.total

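A short usage sketch for the custom metric above (values are illustrative):

.. code-block:: python

    my_accuracy = MyAccuracy()

    preds = torch.tensor([0, 1, 1, 3])
    target = torch.tensor([0, 1, 2, 3])

    # forward() updates the ``correct`` and ``total`` states and returns the value for this batch
    my_accuracy(preds, target)

    # compute() uses the accumulated states
    my_accuracy.compute()  # tensor(0.7500)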
Metric
^^^^^^

.. autoclass:: pytorch_lightning.metrics.Metric
    :noindex:

Classification Metrics
----------------------

Accuracy
^^^^^^^^

@@ -128,510 +128,25 @@ Accuracy

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
    :noindex:

AveragePrecision
^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
    :noindex:

AUROC
^^^^^

.. autoclass:: pytorch_lightning.metrics.classification.AUROC
    :noindex:

Regression Metrics
------------------

MeanSquaredError
^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
    :noindex:

MeanAbsoluteError
^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
    :noindex:

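A brief usage sketch for the regression metrics above (a minimal example, assuming ``MeanSquaredError`` follows the same ``update()``/``compute()`` interface as ``Accuracy``):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import MeanSquaredError

    mean_squared_error = MeanSquaredError()

    preds = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.0, 2.0, 2.0])

    # batch value returned by forward(); the state is accumulated for compute()
    mean_squared_error(preds, target)  # tensor(0.3333)
    mean_squared_error.compute()       # tensor(0.3333)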
BLEUScore
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.nlp.BLEUScore
|
||||
:noindex:
|
||||
|
||||
ConfusionMatrix
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
|
||||
:noindex:
|
||||
|
||||
DiceCoefficient
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.DiceCoefficient
|
||||
:noindex:
|
||||
|
||||
EmbeddingSimilarity
^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.self_supervised.EmbeddingSimilarity
    :noindex:

F1
|
||||
^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.F1
|
||||
:noindex:
|
||||
|
||||
FBeta
|
||||
^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
|
||||
:noindex:
|
||||
|
||||
PrecisionRecallCurve
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
|
||||
:noindex:
|
||||
|
||||
Precision
|
||||
^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.Precision
|
||||
:noindex:
|
||||
|
||||
Recall
|
||||
^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.Recall
|
||||
:noindex:
|
||||
|
||||
ROC
|
||||
^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.ROC
|
||||
:noindex:
|
||||
|
||||
MAE
|
||||
^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.regression.MAE
|
||||
:noindex:
|
||||
|
||||
MSE
|
||||
^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.regression.MSE
|
||||
:noindex:
|
||||
|
||||
MulticlassROC
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.MulticlassROC
|
||||
:noindex:
|
||||
|
||||
MulticlassPrecisionRecallCurve
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.MulticlassPrecisionRecallCurve
|
||||
:noindex:
|
||||
|
||||
IoU
|
||||
^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.classification.IoU
|
||||
:noindex:
|
||||
|
||||
RMSE
|
||||
^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.regression.RMSE
|
||||
:noindex:
|
||||
|
||||
RMSLE
|
||||
^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.regression.RMSLE
|
||||
:noindex:
|
||||
|
||||
SSIM
|
||||
^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.regression.SSIM
|
||||
:noindex:
|
||||
|
||||
----------------
|
||||
|
||||
Functional Metrics
------------------
Functional metrics can be called anywhere (even used with just plain PyTorch).

.. code-block:: python

    from pytorch_lightning.metrics.functional import accuracy

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 2, 2])

    # calculates accuracy across all GPUs and all Nodes used in training
    accuracy(pred, target)

These metrics even work when using distributed training:

.. code-block:: python

    class MyModule(...):
        def forward(self, x, y):
            return accuracy(x, y)

    model = MyModule()
    trainer = Trainer(gpus=8, num_nodes=2)

    # any metric automatically reduces across GPUs (even the ones you implement using Lightning)
    trainer.fit(model)

accuracy (F)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.accuracy
|
||||
:noindex:
|
||||
|
||||
auc (F)
|
||||
^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.auc
|
||||
:noindex:
|
||||
|
||||
auroc (F)
|
||||
^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.auroc
|
||||
:noindex:
|
||||
|
||||
average_precision (F)
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.average_precision
|
||||
:noindex:
|
||||
|
||||
bleu_score (F)
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.bleu_score
|
||||
:noindex:
|
||||
|
||||
confusion_matrix (F)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
|
||||
:noindex:
|
||||
|
||||
dice_score (F)
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.dice_score
|
||||
:noindex:
|
||||
|
||||
embedding_similarity (F)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity
|
||||
:noindex:
|
||||
|
||||
f1_score (F)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.f1_score
|
||||
:noindex:
|
||||
|
||||
fbeta_score (F)
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.fbeta_score
|
||||
:noindex:
|
||||
|
||||
multiclass_precision_recall_curve (F)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.multiclass_precision_recall_curve
|
||||
:noindex:
|
||||
|
||||
multiclass_roc (F)
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.multiclass_roc
|
||||
:noindex:
|
||||
|
||||
precision (F)
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.precision
|
||||
:noindex:
|
||||
|
||||
precision_recall (F)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
|
||||
:noindex:
|
||||
|
||||
precision_recall_curve (F)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
|
||||
:noindex:
|
||||
|
||||
recall (F)
|
||||
^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.recall
|
||||
:noindex:
|
||||
|
||||
roc (F)
|
||||
^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.roc
|
||||
:noindex:
|
||||
|
||||
stat_scores (F)
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
|
||||
:noindex:
|
||||
|
||||
iou (F)
|
||||
^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.iou
|
||||
:noindex:
|
||||
|
||||
mse (F)
|
||||
^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.mse
|
||||
:noindex:
|
||||
|
||||
rmse (F)
|
||||
^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.rmse
|
||||
:noindex:
|
||||
|
||||
mae (F)
|
||||
^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.mae
|
||||
:noindex:
|
||||
|
||||
rmsle (F)
|
||||
^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.rmsle
|
||||
:noindex:
|
||||
|
||||
psnr (F)
|
||||
^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.psnr
|
||||
:noindex:
|
||||
|
||||
ssim (F)
|
||||
^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.ssim
|
||||
:noindex:
|
||||
|
||||
stat_scores_multiple_classes (F)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
|
||||
:noindex:
|
||||
|
||||
----------------
|
||||
|
||||
Metric pre-processing
|
||||
---------------------
|
||||
|
||||
to_categorical (F)
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.functional.to_categorical
|
||||
:noindex:
|
||||
|
||||
to_onehot (F)
^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.functional.to_onehot
    :noindex:

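A small illustrative sketch of these pre-processing helpers; the argument names (``num_classes``, argmax over the class dimension) are assumptions based on the functional API listed above and should be checked against the actual signatures:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics.functional import to_onehot, to_categorical

    labels = torch.tensor([0, 2, 1])

    # class indices -> one-hot encoding of shape (N, num_classes); kwarg name assumed
    onehot = to_onehot(labels, num_classes=3)

    # probabilities/one-hot -> class indices (argmax over the class dimension)
    to_categorical(onehot.float())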
----------------
|
||||
|
||||
Sklearn interface
-----------------

Lightning supports `sklearn's metrics module <https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics>`_
as a backend for calculating metrics. Sklearn's metrics are well tested and robust,
but they require conversion between PyTorch and numpy and may therefore slow down your computations.

To use the sklearn backend of metrics, simply import it as

.. code-block:: python

    import pytorch_lightning.metrics.sklearns as plm

    metric = plm.Accuracy(normalize=True)
    val = metric(pred, target)

Each converted sklearn metric has the same interface as its
original counterpart (e.g. accuracy takes the additional ``normalize`` keyword).
Like the native Lightning metrics, these converted sklearn metrics also come
with built-in distributed (DDP) support.

SklearnMetric (sk)
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.SklearnMetric
|
||||
:noindex:
|
||||
|
||||
Accuracy (sk)
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Accuracy
|
||||
:noindex:
|
||||
|
||||
AUC (sk)
|
||||
^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.AUC
|
||||
:noindex:
|
||||
|
||||
AveragePrecision (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.AveragePrecision
|
||||
:noindex:
|
||||
|
||||
BalancedAccuracy (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.BalancedAccuracy
|
||||
:noindex:
|
||||
|
||||
CohenKappaScore (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.CohenKappaScore
|
||||
:noindex:
|
||||
|
||||
ConfusionMatrix (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.ConfusionMatrix
|
||||
:noindex:
|
||||
|
||||
DCG (sk)
|
||||
^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.DCG
|
||||
:noindex:
|
||||
|
||||
F1 (sk)
|
||||
^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.F1
|
||||
:noindex:
|
||||
|
||||
FBeta (sk)
|
||||
^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.FBeta
|
||||
:noindex:
|
||||
|
||||
Hamming (sk)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Hamming
|
||||
:noindex:
|
||||
|
||||
Hinge (sk)
|
||||
^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Hinge
|
||||
:noindex:
|
||||
|
||||
Jaccard (sk)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Jaccard
|
||||
:noindex:
|
||||
|
||||
Precision (sk)
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Precision
|
||||
:noindex:
|
||||
|
||||
Recall (sk)
|
||||
^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.Recall
|
||||
:noindex:
|
||||
|
||||
PrecisionRecallCurve (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.PrecisionRecallCurve
|
||||
:noindex:
|
||||
|
||||
ROC (sk)
|
||||
^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.ROC
|
||||
:noindex:
|
||||
|
||||
AUROC (sk)
|
||||
^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.AUROC
|
||||
:noindex:
|
||||
|
||||
ExplainedVariance (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.ExplainedVariance
|
||||
:noindex:
|
||||
|
||||
MeanAbsoluteError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanAbsoluteError
|
||||
:noindex:
|
||||
|
||||
MeanSquaredError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredError
|
||||
:noindex:
|
||||
|
||||
MeanSquaredLogError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanSquaredLogError
|
||||
:noindex:
|
||||
|
||||
MedianAbsoluteError (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MedianAbsoluteError
|
||||
:noindex:
|
||||
|
||||
R2Score (sk)
|
||||
^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.R2Score
|
||||
:noindex:
|
||||
|
||||
MeanPoissonDeviance (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanPoissonDeviance
|
||||
:noindex:
|
||||
|
||||
MeanGammaDeviance (sk)
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: pytorch_lightning.metrics.sklearns.MeanGammaDeviance
|
||||
:noindex:
|
||||
|
||||
MeanTweedieDeviance (sk)
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: pytorch_lightning.metrics.sklearns.MeanTweedieDeviance
    :noindex:

MeanSquaredLogError
^^^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
    :noindex:

@@ -14,14 +14,13 @@

 import numbers
 from copy import copy
-from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple
+from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable

 import torch
 from torch import Tensor
 import os

-from pytorch_lightning.metrics.converters import sync_ddp_if_available
-from typing import Iterable
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available


 class Result(Dict):

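For reference, a hedged sketch of how the relocated helper is typically called; the keyword names ``group`` and ``reduce_op`` are assumptions and should be verified against ``pytorch_lightning.utilities.distributed``:

.. code-block:: python

    import torch
    from pytorch_lightning.utilities.distributed import sync_ddp_if_available

    value = torch.tensor(1.0)

    # reduces the tensor across processes when torch.distributed is initialized,
    # and returns it unchanged otherwise
    value = sync_ddp_if_available(value, group=None, reduce_op="sum")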
@@ -1,66 +1,4 @@

-from pytorch_lightning.metrics.classification import (
-    Accuracy,
-    AveragePrecision,
-    ConfusionMatrix,
-    F1,
-    FBeta,
-    Recall,
-    ROC,
-    AUROC,
-    DiceCoefficient,
-    MulticlassPrecisionRecallCurve,
-    MulticlassROC,
-    Precision,
-    PrecisionRecallCurve,
-    IoU,
-)
-from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric
-from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
-from pytorch_lightning.metrics.nlp import BLEUScore
-from pytorch_lightning.metrics.self_supervised import EmbeddingSimilarity
-from pytorch_lightning.metrics.regression import (
-    MAE,
-    MSE,
-    PSNR,
-    RMSE,
-    RMSLE,
-    SSIM
-)
-from pytorch_lightning.metrics.sklearns import (
-    AUC,
-    SklearnMetric,
-)
+from pytorch_lightning.metrics.metric import Metric
-
-__classification_metrics = [
-    "AUC",
-    "AUROC",
-    "Accuracy",
-    "AveragePrecision",
-    "ConfusionMatrix",
-    "DiceCoefficient",
-    "F1",
-    "FBeta",
-    "MulticlassPrecisionRecallCurve",
-    "MulticlassROC",
-    "Precision",
-    "PrecisionRecallCurve",
-    "ROC",
-    "Recall",
-    "IoU",
-]
-__regression_metrics = [
-    "MAE",
-    "MSE",
-    "PSNR",
-    "RMSE",
-    "RMSLE",
-    "SSIM"
-]
-__sequence_metrics = ["BLEUScore"]
-__selfsuper_metrics = ["EmbeddingSimilarity"]
-
-__all__ = __regression_metrics \
-    + __classification_metrics \
-    + __selfsuper_metrics \
-    + __sequence_metrics \
-    + ["SklearnMetric"]
+from pytorch_lightning.metrics.classification.accuracy import Accuracy
+from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError

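After this change, the package root re-exports only the new API. A quick sketch of what remains importable, based on the imports above:

.. code-block:: python

    from pytorch_lightning.metrics import (
        Metric,
        Accuracy,
        MeanSquaredError,
        MeanAbsoluteError,
        MeanSquaredLogError,
    )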
@@ -1,866 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.classification import (
|
||||
accuracy,
|
||||
auroc,
|
||||
average_precision,
|
||||
confusion_matrix,
|
||||
_confmat_normalize,
|
||||
dice_score,
|
||||
f1_score,
|
||||
fbeta_score,
|
||||
iou,
|
||||
multiclass_precision_recall_curve,
|
||||
multiclass_roc,
|
||||
precision_recall_curve,
|
||||
roc,
|
||||
precision_recall
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.reduction import class_reduce
|
||||
from pytorch_lightning.metrics.metric import TensorMetric
|
||||
|
||||
|
||||
class Accuracy(TensorMetric):
|
||||
"""
|
||||
Computes the accuracy classification score
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = Accuracy()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.7500)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
class_reduction: str = 'micro',
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
class_reduction: method to reduce metric score over labels
|
||||
|
||||
- ``'micro'``: calculate metrics globally (default)
|
||||
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
|
||||
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
|
||||
- ``'none'``: returns calculated metric per class
|
||||
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(name="accuracy", reduce_group=reduce_group)
|
||||
self.num_classes = num_classes
|
||||
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
|
||||
self.class_reduction = class_reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the classification score.
|
||||
"""
|
||||
return accuracy(pred=pred, target=target,
|
||||
num_classes=self.num_classes,
|
||||
class_reduction='none',
|
||||
return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
tps, sups = output['tps'], output['sups']
|
||||
return class_reduce(tps, sups, sups, class_reduction=self.class_reduction)
|
||||
|
||||
|
||||
class ConfusionMatrix(TensorMetric):
|
||||
"""
|
||||
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
|
||||
in group i that were predicted in group j.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 2])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = ConfusionMatrix()
|
||||
>>> metric(pred, target)
|
||||
tensor([[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 2.]])
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
normalize: bool = False,
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
normalize: whether to compute a normalized confusion matrix
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="confusion_matrix",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
self.normalize = normalize
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the confusion matrix.
|
||||
"""
|
||||
return confusion_matrix(pred=pred, target=target,
|
||||
normalize=False, # we normalize after ddp sync
|
||||
num_classes=self.num_classes)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
""" Confusion matrix normalization needs to happen after ddp sync """
|
||||
confmat = output
|
||||
if self.normalize:
|
||||
confmat = _confmat_normalize(confmat)
|
||||
return confmat
|
||||
|
||||
|
||||
class PrecisionRecallCurve(TensorMetric):
|
||||
"""
|
||||
Computes the precision recall curve
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = PrecisionRecallCurve()
|
||||
>>> prec, recall, thr = metric(pred, target)
|
||||
>>> prec
|
||||
tensor([0.3333, 0.0000, 0.0000, 1.0000])
|
||||
>>> recall
|
||||
tensor([1., 0., 0., 0.])
|
||||
>>> thr
|
||||
tensor([1., 2., 3.])
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="precision_recall_curve",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
- precision values
|
||||
- recall values
|
||||
- threshold values
|
||||
"""
|
||||
return precision_recall_curve(pred=pred, target=target,
|
||||
sample_weight=sample_weight, pos_label=self.pos_label)
|
||||
|
||||
|
||||
class Precision(TensorMetric):
|
||||
"""
|
||||
Computes the precision score
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = Precision(num_classes=4, class_reduction='macro')
|
||||
>>> metric(pred, target)
|
||||
tensor(0.7500)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
class_reduction: str = 'micro',
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
class_reduction: method to reduce metric score over labels
|
||||
|
||||
- ``'micro'``: calculate metrics globally (default)
|
||||
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
|
||||
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
|
||||
- ``'none'``: returns calculated metric per class
|
||||
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="precision",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
self.num_classes = num_classes
|
||||
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
|
||||
self.class_reduction = class_reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the classification score.
|
||||
"""
|
||||
return precision_recall(pred=pred, target=target,
|
||||
num_classes=self.num_classes,
|
||||
class_reduction='none',
|
||||
return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
tps, fps, sups = output['tps'], output['fps'], output['sups']
|
||||
return class_reduce(tps, tps + fps, sups, class_reduction=self.class_reduction)
|
||||
|
||||
|
||||
class Recall(TensorMetric):
|
||||
"""
|
||||
Computes the recall score
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = Recall()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.7500)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
class_reduction: str = 'micro',
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
class_reduction: method to reduce metric score over labels
|
||||
|
||||
- ``'micro'``: calculate metrics globally (default)
|
||||
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
|
||||
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
|
||||
- ``'none'``: returns calculated metric per class
|
||||
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="recall",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.num_classes = num_classes
|
||||
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
|
||||
self.class_reduction = class_reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the classification score.
|
||||
"""
|
||||
return precision_recall(pred=pred, target=target,
|
||||
num_classes=self.num_classes,
|
||||
class_reduction='none',
|
||||
return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
tps, fns, sups = output['tps'], output['fns'], output['sups']
|
||||
return class_reduce(tps, tps + fns, sups, class_reduction=self.class_reduction)
|
||||
|
||||
|
||||
class AveragePrecision(TensorMetric):
|
||||
"""
|
||||
Computes the average precision score
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = AveragePrecision()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.3333)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="AP",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
torch.Tensor: classification score
|
||||
"""
|
||||
return average_precision(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
|
||||
|
||||
|
||||
class AUROC(TensorMetric):
|
||||
"""
|
||||
Computes the area under curve (AUC) of the receiver operator characteristic (ROC)
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 1, 0])
|
||||
>>> metric = AUROC()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.5000)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="auroc",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
torch.Tensor: classification score
|
||||
"""
|
||||
return auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
|
||||
|
||||
|
||||
class FBeta(TensorMetric):
|
||||
"""
|
||||
Computes the FBeta Score, which is the weighted harmonic mean of precision and recall.
|
||||
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = FBeta(0.25, class_reduction='macro')
|
||||
>>> metric(pred, target)
|
||||
tensor(0.7361)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
num_classes: Optional[int] = None,
|
||||
class_reduction: str = 'micro',
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
beta: determines the weight of recall in the combined score.
|
||||
num_classes: number of classes
|
||||
class_reduction: method to reduce metric score over labels
|
||||
|
||||
- ``'micro'``: calculate metrics globally (default)
|
||||
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
|
||||
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
|
||||
- ``'none'``: returns calculated metric per class
|
||||
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="fbeta",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.beta = beta
|
||||
self.num_classes = num_classes
|
||||
assert class_reduction in ('micro', 'macro', 'weighted', 'none')
|
||||
self.class_reduction = class_reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
|
||||
Return:
|
||||
torch.Tensor: classification score
|
||||
"""
|
||||
return precision_recall(pred=pred, target=target,
|
||||
num_classes=self.num_classes,
|
||||
class_reduction='none',
|
||||
return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
""" tps, fps, fns, sups needs to be synced before we do any calculations """
|
||||
tps, fps, fns, sups = output['tps'], output['fps'], output['fns'], output['sups']
|
||||
|
||||
intermidiate_reduction = 'none' if self.class_reduction != "micro" else 'micro'
|
||||
precision = class_reduce(tps, tps + fps, sups, class_reduction=intermidiate_reduction)
|
||||
recall = class_reduce(tps, tps + fns, sups, class_reduction=intermidiate_reduction)
|
||||
|
||||
num = (1 + self.beta ** 2) * precision * recall
|
||||
denom = ((self.beta ** 2) * precision + recall)
|
||||
if intermidiate_reduction == 'micro':
|
||||
return torch.sum(num) / torch.sum(denom)
|
||||
return class_reduce(num, denom, sups, class_reduction=self.class_reduction)
|
||||
|
||||
|
||||
class F1(FBeta):
|
||||
"""
|
||||
Computes the F1 score, which is the harmonic mean of the precision and recall.
|
||||
It ranges between 1 and 0, where 1 is perfect and the worst value is 0.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = F1(class_reduction='macro')
|
||||
>>> metric(pred, target)
|
||||
tensor(0.6667)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
class_reduction: str = 'micro',
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
class_reduction: method to reduce metric score over labels
|
||||
|
||||
- ``'micro'``: calculate metrics globally (default)
|
||||
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
|
||||
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
|
||||
- ``'none'``: returns calculated metric per class
|
||||
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(beta=1.0,
|
||||
num_classes=num_classes,
|
||||
class_reduction=class_reduction,
|
||||
reduce_group=reduce_group)
|
||||
self.name = "f1"
|
||||
|
||||
|
||||
class ROC(TensorMetric):
|
||||
"""
|
||||
Computes the Receiver Operator Characteristic (ROC)
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0, 1, 2, 3])
|
||||
>>> target = torch.tensor([0, 1, 2, 2])
|
||||
>>> metric = ROC()
|
||||
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
|
||||
(tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
|
||||
tensor([0., 0., 0., 1., 1.]),
|
||||
tensor([4., 3., 2., 1., 0.]))
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_label: int = 1,
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
pos_label: positive label indicator
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="roc",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.pos_label = pos_label
|
||||
|
||||
def forward(
|
||||
self, pred: torch.Tensor, target: torch.Tensor, sample_weight: Optional[Sequence] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: groundtruth labels
|
||||
sample_weight: the weights per sample
|
||||
|
||||
Return:
|
||||
- false positive rate
|
||||
- true positive rate
|
||||
- thresholds
|
||||
"""
|
||||
return roc(pred=pred, target=target, sample_weight=sample_weight, pos_label=self.pos_label)
|
||||
|
||||
|
||||
class MulticlassROC(TensorMetric):
|
||||
"""
|
||||
Computes the multiclass ROC
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
... [0.05, 0.85, 0.05, 0.05],
|
||||
... [0.05, 0.05, 0.85, 0.05],
|
||||
... [0.05, 0.05, 0.05, 0.85]])
|
||||
>>> target = torch.tensor([0, 1, 3, 2])
|
||||
>>> metric = MulticlassROC()
|
||||
>>> classes_roc = metric(pred, target)
|
||||
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
|
||||
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
|
||||
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
|
||||
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
|
||||
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="multiclass_roc",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted probability for each label
|
||||
target: groundtruth labels
|
||||
sample_weight: Weights for each sample defining the sample's impact on the score
|
||||
|
||||
Return:
|
||||
tuple: A tuple consisting of one tuple per class, holding false positive rate, true positive rate and thresholds
|
||||
|
||||
"""
|
||||
return multiclass_roc(pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes)
|
||||
|
||||
def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""Aggregates results by stacking them instead of concatenating before averaging.
|
||||
|
||||
Returns:
|
||||
the aggregated results
|
||||
"""
|
||||
|
||||
return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])
|
||||
|
||||
|
||||
class MulticlassPrecisionRecallCurve(TensorMetric):
|
||||
"""Computes the multiclass PR Curve
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
... [0.05, 0.85, 0.05, 0.05],
|
||||
... [0.05, 0.05, 0.85, 0.05],
|
||||
... [0.05, 0.05, 0.05, 0.85]])
|
||||
>>> target = torch.tensor([0, 1, 3, 2])
|
||||
>>> metric = MulticlassPrecisionRecallCurve()
|
||||
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
|
||||
((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
|
||||
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
|
||||
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])),
|
||||
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_classes: Optional[int] = None,
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: number of classes
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
name="multiclass_precision_recall_curve",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.num_classes = num_classes
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
sample_weight: Optional[Sequence] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted probability for each label
|
||||
target: groundtruth labels
|
||||
sample_weight: Weights for each sample defining the sample's impact on the score
|
||||
|
||||
Return:
|
||||
tuple: A tuple consisting of one tuple per class, holding precision, recall and thresholds
|
||||
|
||||
"""
|
||||
return multiclass_precision_recall_curve(
|
||||
pred=pred, target=target, sample_weight=sample_weight, num_classes=self.num_classes
|
||||
)
|
||||
|
||||
def aggregate(self, *tensors: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
"""Aggregates results by stacking them instead of concatenating before averaging.
|
||||
|
||||
Returns:
|
||||
the aggregated results
|
||||
"""
|
||||
|
||||
return tuple([tuple([torch.stack(tmps).mean(0) for tmps in zip(*_tensors)]) for _tensors in zip(*tensors)])
|
||||
|
||||
|
||||
class DiceCoefficient(TensorMetric):
|
||||
"""
|
||||
Computes the dice coefficient
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
... [0.05, 0.85, 0.05, 0.05],
|
||||
... [0.05, 0.05, 0.85, 0.05],
|
||||
... [0.05, 0.05, 0.05, 0.85]])
|
||||
>>> target = torch.tensor([0, 1, 3, 2])
|
||||
>>> metric = DiceCoefficient()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.3333)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
include_background: bool = False,
|
||||
nan_score: float = 0.0,
|
||||
no_fg_score: float = 0.0,
|
||||
reduction: str = "elementwise_mean",
|
||||
reduce_group: Any = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
include_background: whether to also compute dice for the background
|
||||
nan_score: score to return, if a NaN occurs during computation (denom zero)
|
||||
no_fg_score: score to return, if no foreground pixel was found in target
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
"""
|
||||
super().__init__(
|
||||
name="dice",
|
||||
reduce_group=reduce_group,
|
||||
)
|
||||
|
||||
self.include_background = include_background
|
||||
self.nan_score = nan_score
|
||||
self.no_fg_score = no_fg_score
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted probability for each label
|
||||
target: groundtruth labels
|
||||
|
||||
Return:
|
||||
torch.Tensor: the calculated dice coefficient
|
||||
"""
|
||||
return dice_score(
|
||||
pred=pred,
|
||||
target=target,
|
||||
bg=self.include_background,
|
||||
nan_score=self.nan_score,
|
||||
no_fg_score=self.no_fg_score,
|
||||
reduction=self.reduction,
|
||||
)
|
||||
|
||||
|
||||
class IoU(TensorMetric):
|
||||
"""
|
||||
Computes the intersection over union.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
|
||||
... [0, 0, 1, 1, 1, 0, 0, 0],
|
||||
... [0, 0, 0, 0, 0, 0, 0, 0]])
|
||||
>>> target = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0],
|
||||
... [0, 0, 0, 1, 1, 1, 0, 0],
|
||||
... [0, 0, 0, 0, 0, 0, 0, 0]])
|
||||
>>> metric = IoU()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.7045)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ignore_index: Optional[int] = None,
|
||||
absent_score: float = 0.0,
|
||||
num_classes: Optional[int] = None,
|
||||
reduction: str = "elementwise_mean",
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
ignore_index: optional int specifying a target class to ignore. If given, this class index does not
|
||||
contribute to the returned score, regardless of reduction method. Has no effect if given an int that is
|
||||
not in the range [0, num_classes-1], where num_classes is either given or derived from pred and target.
|
||||
By default, no index is ignored, and all classes are used.
|
||||
absent_score: score to use for an individual class, if no instances of the class index were present in
|
||||
`y_pred` AND no instances of the class index were present in `y_true`. For example, if we have 3
|
||||
classes, [0, 0] for `y_pred`, and [0, 2] for `y_true`, then class 1 would be assigned the
|
||||
`absent_score`. Default is 0.0.
|
||||
num_classes: Optionally specify the number of classes
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
"""
|
||||
super().__init__(name="iou")
|
||||
self.ignore_index = ignore_index
|
||||
self.absent_score = absent_score
|
||||
self.num_classes = num_classes
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, sample_weight: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Actual metric calculation.
|
||||
"""
|
||||
return iou(
|
||||
pred=y_pred,
|
||||
target=y_true,
|
||||
ignore_index=self.ignore_index,
|
||||
absent_score=self.absent_score,
|
||||
num_classes=self.num_classes,
|
||||
reduction=self.reduction,
|
||||
)
|
|
@@ -0,0 +1 @@

from pytorch_lightning.metrics.classification.accuracy import Accuracy

@@ -0,0 +1,103 @@
|
|||
import math
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class Accuracy(Metric):
|
||||
"""
|
||||
Computes accuracy. Works with binary, multiclass, and multilabel data.
|
||||
Accepts logits from a model output or integer class values in prediction.
|
||||
Works with multi-dimensional preds and target.
|
||||
|
||||
Forward accepts
|
||||
|
||||
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
|
||||
- ``target`` (long tensor): ``(N, ...)``
|
||||
|
||||
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
|
||||
This is the case for binary and multi-label logits.
|
||||
|
||||
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
|
||||
|
||||
Args:
|
||||
threshold:
|
||||
Threshold value for binary or multi-label logits. default: 0.5
|
||||
compute_on_step:
|
||||
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
|
||||
ddp_sync_on_step:
|
||||
Synchronize metric state across processes at each ``forward()``
|
||||
before returning the value at the step. default: False
|
||||
process_group:
|
||||
Specify the process group on which synchronization is called. default: None (which selects the entire world)
|
||||
|
||||
Example:
|
||||
|
||||
>>> from pytorch_lightning.metrics import Accuracy
|
||||
>>> target = torch.tensor([0, 1, 2, 3])
|
||||
>>> preds = torch.tensor([0, 2, 1, 3])
|
||||
>>> accuracy = Accuracy()
|
||||
>>> accuracy(preds, target)
|
||||
tensor(0.5000)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold: float = 0.5,
|
||||
compute_on_step: bool = True,
|
||||
ddp_sync_on_step: bool = False,
|
||||
process_group: Optional[Any] = None,
|
||||
):
|
||||
super().__init__(
|
||||
compute_on_step=compute_on_step,
|
||||
ddp_sync_on_step=ddp_sync_on_step,
|
||||
process_group=process_group,
|
||||
)
|
||||
|
||||
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
||||
|
||||
self.threshold = threshold
|
||||
|
||||
def _input_format(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
|
||||
raise ValueError(
|
||||
"preds and target must have same number of dimensions, or one additional dimension for preds"
|
||||
)
|
||||
|
||||
if len(preds.shape) == len(target.shape) + 1:
|
||||
# multi-class probabilities
|
||||
preds = torch.argmax(preds, dim=1)
|
||||
|
||||
if len(preds.shape) == len(target.shape) and preds.dtype == torch.float:
|
||||
# binary or multilabel probabilities
|
||||
preds = (preds >= self.threshold).long()
|
||||
|
||||
return preds, target
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model
|
||||
target: Ground truth values
|
||||
"""
|
||||
preds, target = self._input_format(preds, target)
|
||||
assert preds.shape == target.shape
|
||||
|
||||
self.correct += torch.sum(preds == target)
|
||||
self.total += target.numel()
|
||||
|
||||
def compute(self):
|
||||
"""
|
||||
Computes accuracy over state.
|
||||
"""
|
||||
return self.correct.float() / self.total
|
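The two input formats described in the class docstring can be exercised directly; a short usage sketch, not part of the diff, with illustrative values:

import torch
from pytorch_lightning.metrics import Accuracy

# multi-class scores: preds has one extra dimension (N, C), so an argmax over dim=1 is taken
accuracy = Accuracy()
preds = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
target = torch.tensor([0, 0])
print(accuracy(preds, target))  # tensor(0.5000)

# binary probabilities: preds has the same shape as target, so the threshold is applied instead
accuracy = Accuracy(threshold=0.5)
preds = torch.tensor([0.3, 0.7, 0.6, 0.2])
target = torch.tensor([0, 1, 0, 0])
print(accuracy(preds, target))  # tensor(0.7500)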
|
@ -1,410 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This file provides functions and decorators for automated input and output
|
||||
conversion to/from :class:`numpy.ndarray` and :class:`torch.Tensor` as well as utilities to
|
||||
sync tensors between different processes in a DDP scenario, when needed.
|
||||
"""
|
||||
|
||||
from functools import reduce
|
||||
import numbers
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data._utils.collate import np_str_obj_array_pattern
|
||||
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
||||
if torch.distributed.is_available():
|
||||
from torch.distributed import ReduceOp
|
||||
else:
|
||||
class ReduceOp:
|
||||
SUM = None
|
||||
|
||||
|
||||
def _apply_to_inputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
|
||||
"""
|
||||
Decorator function to apply a function to all inputs of a function.
|
||||
|
||||
Args:
|
||||
func_to_apply: the function to apply to the inputs
|
||||
*dec_args: positional arguments for the function to be applied
|
||||
**dec_kwargs: keyword arguments for the function to be applied
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
# actual function applying the given function to inputs
|
||||
def new_func(*args, **kwargs):
|
||||
args = func_to_apply(args, *dec_args, **dec_kwargs)
|
||||
kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs)
|
||||
return func_to_decorate(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def _apply_to_outputs(func_to_apply: Callable, *dec_args, **dec_kwargs) -> Callable:
|
||||
"""
|
||||
Decorator function to apply a function to all outputs of a function.
|
||||
|
||||
Args:
|
||||
func_to_apply: the function to apply to the outputs
|
||||
*dec_args: positional arguments for the function to be applied
|
||||
**dec_kwargs: keyword arguments for the function to be applied
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(function_to_decorate):
|
||||
# actual function applying the given function to outputs
|
||||
def new_func(*args, **kwargs):
|
||||
result = function_to_decorate(*args, **kwargs)
|
||||
return func_to_apply(result, *dec_args, **dec_kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def convert_to_tensor(data: Any, dtype=None, device=None) -> Any:
|
||||
"""
|
||||
Maps all kinds of collections and numbers to tensors.
|
||||
|
||||
Args:
|
||||
data: the data to convert to tensor
|
||||
dtype: data type to convert to
|
||||
device: device to cast to
|
||||
|
||||
Return:
|
||||
the converted data
|
||||
"""
|
||||
if isinstance(data, numbers.Number):
|
||||
return torch.tensor([data], dtype=dtype, device=device)
|
||||
# is not array of object
|
||||
elif isinstance(data, np.ndarray) and np_str_obj_array_pattern.search(data.dtype.str) is None:
|
||||
return torch.from_numpy(data).to(device=device, dtype=dtype)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data.to(device=device, dtype=dtype)
|
||||
|
||||
raise TypeError(f"The given type ('{type(data).__name__}') cannot be converted to a tensor!")
|
||||
|
||||
|
||||
def convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray:
|
||||
"""Convert all tensors and numpy arrays to numpy arrays.
|
||||
|
||||
Args:
|
||||
data: the tensor or array to convert to numpy
|
||||
|
||||
Return:
|
||||
the resulting numpy array
|
||||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.cpu().detach().numpy()
|
||||
elif isinstance(data, numbers.Number):
|
||||
return np.array([data])
|
||||
elif isinstance(data, np.ndarray):
|
||||
return data
|
||||
|
||||
raise TypeError("The given type ('%s') cannot be converted to a numpy array!" % type(data).__name__)
|
||||
|
||||
|
||||
def _numpy_metric_input_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all inputs of a function to numpy
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)(
|
||||
func_to_decorate
|
||||
)
|
||||
|
||||
|
||||
def _tensor_metric_output_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all outputs of a function to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose outputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_outputs(convert_to_tensor)(func_to_decorate)
|
||||
|
||||
|
||||
def _numpy_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator handling the argument conversion for metrics working on numpy.
|
||||
All inputs of the decorated function will be converted to numpy and all
|
||||
outputs will be converted to tensors.
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
# applies collection conversion from tensor to numpy to all inputs
|
||||
# we need to include numpy arrays here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _numpy_metric_input_conversion(func_to_decorate)
|
||||
# converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric)
|
||||
func_convert_in_out = _tensor_metric_output_conversion(func_convert_inputs)
|
||||
return func_convert_in_out
|
||||
|
||||
|
||||
def _tensor_metric_input_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all inputs of a function to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_inputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
|
||||
func_to_decorate
|
||||
)
|
||||
|
||||
|
||||
def _tensor_collection_metric_output_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator converting all numpy arrays and numbers occurring in the outputs of a function to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose outputs shall be converted
|
||||
|
||||
Return:
|
||||
Callable: the decorated function
|
||||
"""
|
||||
return _apply_to_outputs(apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor)(
|
||||
func_to_decorate
|
||||
)
|
||||
|
||||
|
||||
def _tensor_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator handling the argument conversion for metrics working on tensors.
|
||||
All inputs and outputs of the decorated function will be converted to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
# converts all inputs to tensor if possible
|
||||
# we need to include tensors here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
|
||||
# convert all outputs to tensor if possible
|
||||
return _tensor_metric_output_conversion(func_convert_inputs)
|
||||
|
||||
|
||||
def _tensor_collection_metric_conversion(func_to_decorate: Callable) -> Callable:
|
||||
"""
|
||||
Decorator handling the argument conversion for metrics working on tensors.
|
||||
All inputs of the decorated function and all numpy arrays and numbers in
|
||||
its outputs will be converted to tensors
|
||||
|
||||
Args:
|
||||
func_to_decorate: the function whose inputs and outputs shall be converted
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
# converts all inputs to tensor if possible
|
||||
# we need to include tensors here, since otherwise they will also be treated as sequences
|
||||
func_convert_inputs = _tensor_metric_input_conversion(func_to_decorate)
|
||||
# convert all outputs to tensor if possible
|
||||
return _tensor_collection_metric_output_conversion(func_convert_inputs)
|
||||
|
||||
|
||||
def sync_ddp_if_available(
|
||||
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Function to reduce the tensors from several ddp processes to one master process
|
||||
|
||||
Args:
|
||||
result: the value to sync and reduce (typically tensor or number)
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum.
|
||||
Can also be a string of 'avg', 'mean' to calculate the mean during reduction.
|
||||
|
||||
Return:
|
||||
reduced value
|
||||
"""
|
||||
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
divide_by_world_size = False
|
||||
|
||||
if group is None:
|
||||
group = torch.distributed.group.WORLD
|
||||
|
||||
if reduce_op is None:
|
||||
reduce_op = torch.distributed.ReduceOp.SUM
|
||||
elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
|
||||
reduce_op = torch.distributed.ReduceOp.SUM
|
||||
divide_by_world_size = True
|
||||
|
||||
# sync all processes before reduction
|
||||
torch.distributed.barrier(group=group)
|
||||
torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
|
||||
|
||||
if divide_by_world_size:
|
||||
result = result / torch.distributed.get_world_size(group)
|
||||
|
||||
return result
|
||||
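A small sketch of what the ``reduce_op='mean'`` path above effectively computes in a hypothetical two-process run; plain tensors stand in for the per-rank values, so no process group is needed to follow the arithmetic:

import torch

# per-rank values (illustrative); the helper would all-reduce these with SUM
rank_results = [torch.tensor(2.0), torch.tensor(4.0)]
world_size = len(rank_results)
reduced = torch.stack(rank_results).sum(0) / world_size
print(reduced)  # tensor(3.) -- every rank would end up holding this mean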
|
||||
|
||||
def at_least_1d(tensor: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
"""Makes sure the tensor is at least of 1d shape
|
||||
|
||||
Args:
|
||||
tensor: the tensor or array to check the shape for
|
||||
|
||||
Returns:
|
||||
the optionally reshaped tensor
|
||||
"""
|
||||
if tensor.shape == ():
|
||||
tensor = tensor.reshape(1, )
|
||||
return tensor
|
||||
|
||||
|
||||
def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
|
||||
"""
|
||||
Function to gather all tensors from several ddp processes onto a list that
|
||||
is broadcasted to all processes
|
||||
|
||||
Args:
|
||||
result: the value to sync
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
|
||||
Return:
|
||||
gathered_result: list with size equal to the process group where
|
||||
gathered_result[i] corresponds to result tensor from process i
|
||||
|
||||
"""
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
||||
if group is None:
|
||||
group = torch.distributed.group.WORLD
|
||||
|
||||
world_size = torch.distributed.get_world_size(group)
|
||||
|
||||
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
|
||||
|
||||
# sync and broadcast all
|
||||
torch.distributed.barrier(group=group)
|
||||
torch.distributed.all_gather(gathered_result, result, group)
|
||||
|
||||
result = gathered_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def sync_ddp(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator syncs a function's outputs across different processes for DDP.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return _apply_to_outputs(
|
||||
apply_to_collection, torch.Tensor, sync_ddp_if_available, group=group, reduce_op=reduce_op
|
||||
)(func_to_decorate)
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def numpy_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on numpy arrays.
|
||||
It handles the argument conversion and DDP reduction for metrics working on numpy.
|
||||
All inputs of the decorated function will be converted to numpy and all
|
||||
outputs will be converted to tensors.
|
||||
In DDP Training all output tensors will be reduced according to the given rules.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return sync_ddp(group=group, reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def tensor_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on tensors.
|
||||
It handles the argument conversion and DDP reduction for metrics working on tensors.
|
||||
All inputs and outputs of the decorated function will be converted to tensors.
|
||||
In DDP Training all output tensors will be reduced according to the given rules.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
||||
|
||||
|
||||
def tensor_collection_metric(group: Optional[Any] = None, reduce_op: Optional[ReduceOp] = None) -> Callable:
|
||||
"""
|
||||
This decorator shall be used on all function metrics working on tensors and returning collections
|
||||
that cannot be converted to tensors.
|
||||
It handles the argument conversion and DDP reduction for metrics working on tensors.
|
||||
All inputs and outputs of the decorated function will be converted to tensors.
|
||||
In DDP Training all output tensors will be reduced according to the given rules.
|
||||
|
||||
Args:
|
||||
group: the process group to gather results from. Defaults to all processes (world)
|
||||
reduce_op: the reduction operation. Defaults to sum
|
||||
|
||||
Return:
|
||||
the decorated function
|
||||
"""
|
||||
|
||||
def decorator_fn(func_to_decorate):
|
||||
return sync_ddp(group=group, reduce_op=reduce_op)(_tensor_collection_metric_conversion(func_to_decorate))
|
||||
|
||||
return decorator_fn
|
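The conversion pattern that ``numpy_metric`` / ``_numpy_metric_conversion`` implement can be sketched without the removed helpers; the decorator and metric below are made-up names for illustration: tensors go in, the body works on numpy, and the result comes back as a tensor.

import numpy as np
import torch

def numpy_in_tensor_out(fn):
    # tensors in -> numpy for the metric body -> tensor out
    def wrapper(*args):
        args = [a.cpu().numpy() if isinstance(a, torch.Tensor) else a for a in args]
        return torch.as_tensor(fn(*args))
    return wrapper

@numpy_in_tensor_out
def mean_error(pred, target):
    return np.mean(np.abs(pred - target))

print(mean_error(torch.tensor([0.0, 1.0]), torch.tensor([1.0, 1.0])))  # tensor(0.5000)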
|
@ -1,34 +0,0 @@
|
|||
from pytorch_lightning.metrics.functional.classification import (
|
||||
accuracy,
|
||||
auc,
|
||||
auroc,
|
||||
average_precision,
|
||||
confusion_matrix,
|
||||
dice_score,
|
||||
f1_score,
|
||||
fbeta_score,
|
||||
multiclass_precision_recall_curve,
|
||||
multiclass_roc,
|
||||
precision,
|
||||
precision_recall,
|
||||
precision_recall_curve,
|
||||
recall,
|
||||
roc,
|
||||
stat_scores,
|
||||
stat_scores_multiple_classes,
|
||||
to_categorical,
|
||||
to_onehot,
|
||||
iou,
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
from pytorch_lightning.metrics.functional.regression import (
|
||||
mae,
|
||||
mse,
|
||||
psnr,
|
||||
rmse,
|
||||
rmsle,
|
||||
ssim
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.self_supervised import (
|
||||
embedding_similarity
|
||||
)
|
File diff suppressed because it is too large
|
@ -1,103 +0,0 @@
|
|||
# referenced from
|
||||
# Library Name: torchtext
|
||||
# Authors: torchtext authors and @sluks
|
||||
# Date: 2020-07-18
|
||||
# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
|
||||
from collections import Counter
|
||||
from typing import List, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
|
||||
"""
|
||||
Counts how many times each n-gram appears in a given text
|
||||
|
||||
Args:
|
||||
ngram_input_list: A list of translated text or reference texts
|
||||
n_gram: gram value, ranging from 1 to 4
|
||||
|
||||
Return:
|
||||
ngram_counter: a collections.Counter object of ngram
|
||||
"""
|
||||
|
||||
ngram_counter = Counter()
|
||||
|
||||
for i in range(1, n_gram + 1):
|
||||
for j in range(len(ngram_input_list) - i + 1):
|
||||
ngram_key = tuple(ngram_input_list[j:(i + j)])
|
||||
ngram_counter[ngram_key] += 1
|
||||
|
||||
return ngram_counter
|
||||
|
||||
|
||||
def bleu_score(
|
||||
translate_corpus: Sequence[str],
|
||||
reference_corpus: Sequence[str],
|
||||
n_gram: int = 4,
|
||||
smooth: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate BLEU score of machine translated text with one or more references
|
||||
|
||||
Args:
|
||||
translate_corpus: An iterable of machine translated corpus
|
||||
reference_corpus: An iterable of iterables of reference corpus
|
||||
n_gram: Gram value ranged from 1 to 4 (Default 4)
|
||||
smooth: Whether or not to apply smoothing – Lin et al. 2004
|
||||
|
||||
Return:
|
||||
Tensor with BLEU Score
|
||||
|
||||
Example:
|
||||
|
||||
>>> translate_corpus = ['the cat is on the mat'.split()]
|
||||
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
|
||||
>>> bleu_score(translate_corpus, reference_corpus)
|
||||
tensor(0.7598)
|
||||
|
||||
"""
|
||||
|
||||
assert len(translate_corpus) == len(reference_corpus)
|
||||
numerator = torch.zeros(n_gram)
|
||||
denominator = torch.zeros(n_gram)
|
||||
precision_scores = torch.zeros(n_gram)
|
||||
c = 0.0
|
||||
r = 0.0
|
||||
|
||||
for (translation, references) in zip(translate_corpus, reference_corpus):
|
||||
c += len(translation)
|
||||
ref_len_list = [len(ref) for ref in references]
|
||||
ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
|
||||
r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
|
||||
translation_counter = _count_ngram(translation, n_gram)
|
||||
reference_counter = Counter()
|
||||
|
||||
for ref in references:
|
||||
reference_counter |= _count_ngram(ref, n_gram)
|
||||
|
||||
ngram_counter_clip = translation_counter & reference_counter
|
||||
|
||||
for counter_clip in ngram_counter_clip:
|
||||
numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
|
||||
|
||||
for counter in translation_counter:
|
||||
denominator[len(counter) - 1] += translation_counter[counter]
|
||||
|
||||
trans_len = torch.tensor(c)
|
||||
ref_len = torch.tensor(r)
|
||||
|
||||
if min(numerator) == 0.0:
|
||||
return torch.tensor(0.0)
|
||||
|
||||
if smooth:
|
||||
precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram))
|
||||
else:
|
||||
precision_scores = numerator / denominator
|
||||
|
||||
log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores)
|
||||
geometric_mean = torch.exp(torch.sum(log_precision_scores))
|
||||
brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len))
|
||||
bleu = brevity_penalty * geometric_mean
|
||||
|
||||
return bleu
|
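For reference, the n-gram tallying that ``_count_ngram`` performs can be reproduced in a few standalone lines (``n_gram=2`` here):

from collections import Counter

def count_ngrams(tokens, n_gram=2):
    counter = Counter()
    for n in range(1, n_gram + 1):
        for j in range(len(tokens) - n + 1):
            counter[tuple(tokens[j:j + n])] += 1
    return counter

counts = count_ngrams("the cat is on the mat".split())
print(counts[("the",)])       # 2
print(counts[("on", "the")])  # 1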
|
@ -1,65 +0,0 @@
|
|||
import torch
|
||||
|
||||
|
||||
def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor:
|
||||
"""
|
||||
Reduces a given tensor by a given reduction method
|
||||
|
||||
Args:
|
||||
to_reduce : the tensor, which shall be reduced
|
||||
reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
|
||||
|
||||
Return:
|
||||
reduced Tensor
|
||||
|
||||
Raise:
|
||||
ValueError if an invalid reduction parameter was given
|
||||
"""
|
||||
if reduction == 'elementwise_mean':
|
||||
return torch.mean(to_reduce)
|
||||
if reduction == 'none':
|
||||
return to_reduce
|
||||
if reduction == 'sum':
|
||||
return torch.sum(to_reduce)
|
||||
raise ValueError('Reduction parameter unknown.')
|
||||
|
||||
|
||||
def class_reduce(num: torch.Tensor,
|
||||
denom: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
class_reduction: str = 'none') -> torch.Tensor:
|
||||
"""
|
||||
Function used to reduce classification metrics of the form `num / denom * weights`.
|
||||
For example, for calculating standard accuracy, the num would be the number of
|
||||
true positives per class, denom would be the support per class, and weights
|
||||
would be a tensor of 1s
|
||||
|
||||
Args:
|
||||
num: numerator tensor
|
||||
denom: denominator tensor
|
||||
weights: weights for each class
|
||||
class_reduction: reduction method for multiclass problems
|
||||
|
||||
- ``'micro'``: calculate metrics globally (default)
|
||||
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
|
||||
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
|
||||
- ``'none'``: returns calculated metric per class
|
||||
|
||||
"""
|
||||
valid_reduction = ('micro', 'macro', 'weighted', 'none')
|
||||
if class_reduction == 'micro':
|
||||
return torch.sum(num) / torch.sum(denom)
|
||||
|
||||
# For the rest we need to take care of instances where the denom can be 0
|
||||
# for some classes which will produce nans for that class
|
||||
fraction = num / denom
|
||||
fraction[fraction != fraction] = 0
|
||||
if class_reduction == 'macro':
|
||||
return torch.mean(fraction)
|
||||
elif class_reduction == 'weighted':
|
||||
return torch.sum(fraction * (weights / torch.sum(weights)))
|
||||
elif class_reduction == 'none':
|
||||
return fraction
|
||||
|
||||
raise ValueError(f'Reduction parameter {class_reduction} unknown.'
|
||||
f' Choose between one of these: {valid_reduction}')
|
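A worked sketch of the four reductions on a 3-class toy example, where class 2 has no support (denom = 0) and therefore contributes 0 rather than NaN:

import torch

num = torch.tensor([2.0, 1.0, 0.0])    # e.g. true positives per class
denom = torch.tensor([4.0, 2.0, 0.0])  # e.g. support per class
weights = denom

micro = num.sum() / denom.sum()                         # tensor(0.5000)
fraction = num / denom
fraction[fraction != fraction] = 0.0                    # zero out the NaN of the empty class
macro = fraction.mean()                                 # tensor(0.3333)
weighted = (fraction * weights / weights.sum()).sum()   # tensor(0.5000)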
|
@ -1,325 +0,0 @@
|
|||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce
|
||||
|
||||
|
||||
def mse(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
reduction: str = 'elementwise_mean',
|
||||
return_state: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes mean squared error
|
||||
|
||||
Args:
|
||||
pred: estimated labels
|
||||
target: ground truth labels
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
return_state: returns an internal state that can be ddp reduced
|
||||
before doing the final calculation
|
||||
|
||||
Return:
|
||||
Tensor with MSE
|
||||
|
||||
Example:
|
||||
|
||||
>>> x = torch.tensor([0., 1, 2, 3])
|
||||
>>> y = torch.tensor([0., 1, 2, 2])
|
||||
>>> mse(x, y)
|
||||
tensor(0.2500)
|
||||
|
||||
"""
|
||||
mse = F.mse_loss(pred, target, reduction='none')
|
||||
if return_state:
|
||||
return {'squared_error': mse.sum(), 'n_observations': torch.tensor(mse.numel())}
|
||||
mse = reduce(mse, reduction=reduction)
|
||||
return mse
|
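The ``return_state`` dict is what makes an exact ddp reduction possible: each process sums its squared errors and observation counts, the sums are reduced, and the division happens once at the end. A sketch with made-up per-rank states:

import torch

state_rank0 = {'squared_error': torch.tensor(1.0), 'n_observations': torch.tensor(2)}
state_rank1 = {'squared_error': torch.tensor(1.0), 'n_observations': torch.tensor(2)}

total_sse = state_rank0['squared_error'] + state_rank1['squared_error']
total_n = state_rank0['n_observations'] + state_rank1['n_observations']
print(total_sse / total_n)  # tensor(0.5000)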
||||
|
||||
|
||||
def rmse(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
reduction: str = 'elementwise_mean',
|
||||
return_state: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes root mean squared error
|
||||
|
||||
Args:
|
||||
pred: estimated labels
|
||||
target: ground truth labels
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
return_state: returns an internal state that can be ddp reduced
|
||||
before doing the final calculation
|
||||
|
||||
Return:
|
||||
Tensor with RMSE
|
||||
|
||||
|
||||
>>> x = torch.tensor([0., 1, 2, 3])
|
||||
>>> y = torch.tensor([0., 1, 2, 2])
|
||||
>>> rmse(x, y)
|
||||
tensor(0.5000)
|
||||
|
||||
"""
|
||||
mean_squared_error = mse(pred, target, reduction=reduction)
|
||||
if return_state:
|
||||
return {'squared_error': mean_squared_error.sum(),
|
||||
'n_observations': torch.tensor(mean_squared_error.numel())}
|
||||
return torch.sqrt(mean_squared_error)
|
||||
|
||||
|
||||
def mae(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
reduction: str = 'elementwise_mean',
|
||||
return_state: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes mean absolute error
|
||||
|
||||
Args:
|
||||
pred: estimated labels
|
||||
target: ground truth labels
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
return_state: returns an internal state that can be ddp reduced
|
||||
before doing the final calculation
|
||||
|
||||
Return:
|
||||
Tensor with MAE
|
||||
|
||||
Example:
|
||||
|
||||
>>> x = torch.tensor([0., 1, 2, 3])
|
||||
>>> y = torch.tensor([0., 1, 2, 2])
|
||||
>>> mae(x, y)
|
||||
tensor(0.2500)
|
||||
|
||||
"""
|
||||
mae = F.l1_loss(pred, target, reduction='none')
|
||||
if return_state:
|
||||
return {'absolute_error': mae.sum(), 'n_observations': torch.tensor(mae.numel())}
|
||||
mae = reduce(mae, reduction=reduction)
|
||||
return mae
|
||||
|
||||
|
||||
def rmsle(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
reduction: str = 'elementwise_mean'
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes root mean squared log error
|
||||
|
||||
Args:
|
||||
pred: estimated labels
|
||||
target: ground truth labels
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
|
||||
Return:
|
||||
Tensor with RMSLE
|
||||
|
||||
Example:
|
||||
|
||||
>>> x = torch.tensor([0., 1, 2, 3])
|
||||
>>> y = torch.tensor([0., 1, 2, 2])
|
||||
>>> rmsle(x, y)
|
||||
tensor(0.1438)
|
||||
|
||||
"""
|
||||
rmsle = rmse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction)
|
||||
return rmsle
|
||||
|
||||
|
||||
def psnr(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
data_range: float = None,
|
||||
base: float = 10.0,
|
||||
reduction: str = 'elementwise_mean',
|
||||
return_state: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes the peak signal-to-noise ratio
|
||||
|
||||
Args:
|
||||
pred: estimated signal
|
||||
target: ground truth signal
|
||||
data_range: the range of the data. If None, it is determined from the data (max - min)
|
||||
base: a base of a logarithm to use (default: 10)
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
return_state: returns an internal state that can be ddp reduced
|
||||
before doing the final calculation
|
||||
|
||||
Return:
|
||||
Tensor with PSNR score
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
|
||||
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
|
||||
>>> psnr(pred, target)
|
||||
tensor(2.5527)
|
||||
|
||||
"""
|
||||
if data_range is None:
|
||||
data_range = target.max() - target.min()
|
||||
else:
|
||||
data_range = torch.tensor(float(data_range))
|
||||
|
||||
if return_state:
|
||||
return {'data_range': data_range,
|
||||
'sum_squared_error': F.mse_loss(pred, target, reduction='none').sum(),
|
||||
'n_obs': torch.tensor(target.numel())}
|
||||
|
||||
mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction)
|
||||
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
|
||||
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
|
||||
return psnr
|
||||
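A quick check of the formula above, psnr = 10 * log10(data_range**2 / mse), against the doctest values: data_range = 3 (max - min of the target) and the per-element squared errors are 9, 1, 1, 9, so mse = 5.

import math

print(10 * math.log10(3.0 ** 2 / 5.0))  # 2.5527... (matches tensor(2.5527))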
|
||||
|
||||
def _gaussian_kernel(channel, kernel_size, sigma, device):
|
||||
def _gaussian(kernel_size, sigma, device):
|
||||
gauss = torch.arange(
|
||||
start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2,
|
||||
step=1,
|
||||
dtype=torch.float32,
|
||||
device=device
|
||||
)
|
||||
gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
|
||||
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
|
||||
|
||||
gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], device)
|
||||
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], device)
|
||||
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)
|
||||
|
||||
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
|
||||
|
||||
|
||||
def ssim(
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
kernel_size: Sequence[int] = (11, 11),
|
||||
sigma: Sequence[float] = (1.5, 1.5),
|
||||
reduction: str = "elementwise_mean",
|
||||
data_range: float = None,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes Structural Similarity Index Measure
|
||||
|
||||
Args:
|
||||
pred: estimated image
|
||||
target: ground truth image
|
||||
kernel_size: size of the gaussian kernel (default: (11, 11))
|
||||
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
|
||||
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
|
||||
k1: Parameter of SSIM. Default: 0.01
|
||||
k2: Parameter of SSIM. Default: 0.03
|
||||
|
||||
Return:
|
||||
Tensor with SSIM score
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.rand([16, 1, 16, 16])
|
||||
>>> target = pred * 0.75
|
||||
>>> ssim(pred, target)
|
||||
tensor(0.9219)
|
||||
|
||||
"""
|
||||
if pred.dtype != target.dtype:
|
||||
raise TypeError(
|
||||
"Expected `pred` and `target` to have the same data type."
|
||||
f" Got pred: {pred.dtype} and target: {target.dtype}."
|
||||
)
|
||||
|
||||
if pred.shape != target.shape:
|
||||
raise ValueError(
|
||||
"Expected `pred` and `target` to have the same shape."
|
||||
f" Got pred: {pred.shape} and target: {target.shape}."
|
||||
)
|
||||
|
||||
if len(pred.shape) != 4 or len(target.shape) != 4:
|
||||
raise ValueError(
|
||||
"Expected `pred` and `target` to have BxCxHxW shape."
|
||||
f" Got pred: {pred.shape} and target: {target.shape}."
|
||||
)
|
||||
|
||||
if len(kernel_size) != 2 or len(sigma) != 2:
|
||||
raise ValueError(
|
||||
"Expected `kernel_size` and `sigma` to have the length of two."
|
||||
f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
|
||||
)
|
||||
|
||||
if any(x % 2 == 0 or x <= 0 for x in kernel_size):
|
||||
raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")
|
||||
|
||||
if any(y <= 0 for y in sigma):
|
||||
raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")
|
||||
|
||||
if data_range is None:
|
||||
data_range = max(pred.max() - pred.min(), target.max() - target.min())
|
||||
|
||||
C1 = pow(k1 * data_range, 2)
|
||||
C2 = pow(k2 * data_range, 2)
|
||||
device = pred.device
|
||||
|
||||
channel = pred.size(1)
|
||||
kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
|
||||
|
||||
# Concatenate
|
||||
# pred for mu_pred
|
||||
# target for mu_target
|
||||
# pred * pred for sigma_pred
|
||||
# target * target for sigma_target
|
||||
# pred * target for sigma_pred_target
|
||||
input_list = torch.cat([pred, target, pred * pred, target * target, pred * target]) # (5 * B, C, H, W)
|
||||
outputs = F.conv2d(input_list, kernel, groups=channel)
|
||||
output_list = [outputs[x * pred.size(0): (x + 1) * pred.size(0)] for x in range(len(outputs))]
|
||||
|
||||
mu_pred_sq = output_list[0].pow(2)
|
||||
mu_target_sq = output_list[1].pow(2)
|
||||
mu_pred_target = output_list[0] * output_list[1]
|
||||
|
||||
sigma_pred_sq = output_list[2] - mu_pred_sq
|
||||
sigma_target_sq = output_list[3] - mu_target_sq
|
||||
sigma_pred_target = output_list[4] - mu_pred_target
|
||||
|
||||
UPPER = 2 * sigma_pred_target + C2
|
||||
LOWER = sigma_pred_sq + sigma_target_sq + C2
|
||||
|
||||
ssim_idx = ((2 * mu_pred_target + C1) * UPPER) / ((mu_pred_sq + mu_target_sq + C1) * LOWER)
|
||||
|
||||
return reduce(ssim_idx, reduction)
|
|
@ -1,46 +0,0 @@
|
|||
import torch
|
||||
|
||||
|
||||
def embedding_similarity(
|
||||
batch: torch.Tensor,
|
||||
similarity: str = 'cosine',
|
||||
reduction: str = 'none',
|
||||
zero_diagonal: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Computes representation similarity
|
||||
|
||||
Example:
|
||||
|
||||
>>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
|
||||
>>> embedding_similarity(embeddings)
|
||||
tensor([[0.0000, 1.0000, 0.9759],
|
||||
[1.0000, 0.0000, 0.9759],
|
||||
[0.9759, 0.9759, 0.0000]])
|
||||
|
||||
Args:
|
||||
batch: (batch, dim)
|
||||
similarity: 'dot' or 'cosine'
|
||||
reduction: 'none', 'sum', 'mean' (all along dim -1)
|
||||
zero_diagonal: if True, the diagonals are set to zero
|
||||
|
||||
Return:
|
||||
A square matrix (batch, batch) with the similarity scores between all elements
|
||||
If sum or mean are used, then returns (b, 1) with the reduced value for each row
|
||||
"""
|
||||
if similarity == 'cosine':
|
||||
norm = torch.norm(batch, p=2, dim=1)
|
||||
batch = batch / norm.unsqueeze(1)
|
||||
|
||||
sqr_mtx = batch.mm(batch.transpose(1, 0))
|
||||
|
||||
if zero_diagonal:
|
||||
sqr_mtx = sqr_mtx.fill_diagonal_(0)
|
||||
|
||||
if reduction == 'mean':
|
||||
sqr_mtx = sqr_mtx.mean(dim=-1)
|
||||
|
||||
if reduction == 'sum':
|
||||
sqr_mtx = sqr_mtx.sum(dim=-1)
|
||||
|
||||
return sqr_mtx
|
|
@ -1,262 +1,220 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Mapping, Optional, Sequence
|
||||
import numbers
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections import namedtuple
|
||||
from copy import deepcopy
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from pytorch_lightning.metrics.converters import (
|
||||
at_least_1d,
|
||||
gather_all_tensors_if_available,
|
||||
convert_to_tensor,
|
||||
convert_to_numpy,
|
||||
)
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available
|
||||
from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
|
||||
|
||||
|
||||
class Metric(DeviceDtypeModuleMixin, nn.Module, ABC):
|
||||
class Metric(nn.Module, ABC):
|
||||
"""
|
||||
Abstract base class for metric implementation.
|
||||
Base class for all metrics present in the Metrics API.
|
||||
|
||||
Should be used to implement metrics that
|
||||
Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to
|
||||
handle distributed synchronization and per step metric computation.
|
||||
|
||||
1. Return multiple Outputs
|
||||
2. Handle their own DDP sync
|
||||
Override ``update()`` and ``compute()`` functions to implement your own metric. Use
|
||||
``add_state()`` to register metric state variables which keep track of state on each
|
||||
call of ``update()`` and are synchronized across processes when ``compute()`` is called.
|
||||
|
||||
Metric hooks that can be implemented are
|
||||
Note:
|
||||
Metric state variables can either be ``torch.Tensors`` or an empty list which can be used
|
||||
to store ``torch.Tensors``.
|
||||
|
||||
* input_convert: pre-forward hook that takes care of input conversion
|
||||
* output_convert: post-forward hook that takes care of output conversion
|
||||
* ddp_reduce: implementation of ddp sync + aggregation, default is ddp_sync + aggregate
|
||||
* compute: post-ddp sync for additional metric computations
|
||||
|
||||
``ddp_reduce`` by default calls the following methods, which can also be overwritten if necessary.
|
||||
|
||||
* ddp_sync: implements how values should be synced across ddp-processes. Defaults to gather all.
|
||||
* aggregate: implement how values should be aggregated (defaults to mean).
|
||||
|
||||
Call order
|
||||
|
||||
input_convert -> forward -> output_convert -> ddp_reduce (per default being ddp_sync -> aggregate) -> compute
|
||||
Note:
|
||||
Different metrics only override ``update()`` and not ``forward()``. A call to ``update()``
|
||||
is valid, but it won't return the metric value at the current step. A call to ``forward()``
|
||||
automatically calls ``update()`` and also returns the metric value at the current step.
|
||||
|
||||
Args:
|
||||
compute_on_step:
|
||||
Forward only calls ``update()`` and returns None if this is set to False. default: True
|
||||
ddp_sync_on_step:
|
||||
Synchronize metric state across processes at each ``forward()``
|
||||
before returning the value at the step. default: False
|
||||
process_group:
|
||||
Specify the process group on which synchronization is called. default: None (which selects the entire world)
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, reduce_group: Optional[Any] = None):
|
||||
"""
|
||||
Args:
|
||||
name: the metric's name
|
||||
reduce_group: the process group for DDP reduces (only needed for DDP training).
|
||||
Defaults to all processes (world)
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
compute_on_step: bool = True,
|
||||
ddp_sync_on_step: bool = False,
|
||||
process_group: Optional[Any] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self._dtype = torch.get_default_dtype()
|
||||
self._device = torch.device("cpu")
|
||||
|
||||
self.reduce_group = reduce_group
|
||||
self.ddp_sync_on_step = ddp_sync_on_step
|
||||
self.compute_on_step = compute_on_step
|
||||
self.process_group = process_group
|
||||
self._to_sync = True
|
||||
|
||||
# Buffer for holding aggregated state after each batch
|
||||
self._step_vals = []
|
||||
self.update = self._wrap_update(self.update)
|
||||
self.compute = self._wrap_compute(self.compute)
|
||||
self._computed = None
|
||||
|
||||
# Register hooks
|
||||
self.register_forward_pre_hook(self.input_convert)
|
||||
self.register_forward_hook(self.output_convert)
|
||||
self.register_forward_hook(self.ddp_reduce)
|
||||
self.register_forward_hook(self.compute)
|
||||
# initialize state
|
||||
self._reductions = {}
|
||||
self._defaults = {}
|
||||
|
||||
@staticmethod
|
||||
def input_convert(self, data: Any):
|
||||
def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None):
|
||||
"""
|
||||
Implement how the inputs should be cast before calling forward
|
||||
Adds metric state variable. Only used by subclasses.
|
||||
|
||||
Args:
|
||||
data: input to forward method
|
||||
name: The name of the state variable. The variable will then be accessible at ``self.name``.
|
||||
default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be
|
||||
reset to this value when ``self.reset()`` is called.
|
||||
dist_reduce_fx (Optional): Function to reduce state across multiple processes in distributed mode.
|
||||
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
|
||||
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
|
||||
function in this parameter.
|
||||
|
||||
Note:
|
||||
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
|
||||
However, there won't be any reduction function applied to the synchronized metric state.
|
||||
|
||||
The metric states would be synced as follows
|
||||
|
||||
- If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across
|
||||
the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric
|
||||
state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``.
|
||||
|
||||
- If the metric state is a ``list``, the synced value will be a ``list`` containing the
|
||||
combined elements from all processes.
|
||||
|
||||
Note:
|
||||
When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
|
||||
the format discussed in the above note.
|
||||
|
||||
Returns:
|
||||
cast data
|
||||
"""
|
||||
return data
|
||||
if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0):
|
||||
raise ValueError(
|
||||
"state variable must be a tensor or any empty list (where you can append tensors)"
|
||||
)
|
||||
|
||||
if dist_reduce_fx == "sum":
|
||||
dist_reduce_fx = dim_zero_sum
|
||||
elif dist_reduce_fx == "mean":
|
||||
dist_reduce_fx = dim_zero_mean
|
||||
elif dist_reduce_fx == "cat":
|
||||
dist_reduce_fx = dim_zero_cat
|
||||
elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable):
|
||||
raise ValueError(
|
||||
"`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]"
|
||||
)
|
||||
|
||||
if isinstance(default, torch.Tensor):
|
||||
self.register_buffer(name, default)
|
||||
else:
|
||||
setattr(self, name, default)
|
||||
|
||||
self._defaults[name] = deepcopy(default)
|
||||
self._reductions[name] = dist_reduce_fx
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs):
|
||||
"""
|
||||
Implements the actual metric computation.
|
||||
|
||||
Returns:
|
||||
metric value or metric state
|
||||
|
||||
Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
# add current step
|
||||
self.update(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def output_convert(self, data: Any, output: Any):
|
||||
if self.compute_on_step:
|
||||
self._to_sync = self.ddp_sync_on_step
|
||||
|
||||
# save context before switch
|
||||
self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()}
|
||||
|
||||
# call reset, update, compute, on single batch
|
||||
self.reset()
|
||||
self.update(*args, **kwargs)
|
||||
result = self.compute()
|
||||
|
||||
# restore context
|
||||
for attr, val in self._cache.items():
|
||||
setattr(self, attr, val)
|
||||
self._to_sync = True
|
||||
self._computed = None
|
||||
|
||||
return result
|
||||
|
||||
def _sync_dist(self):
|
||||
input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}
|
||||
output_dict = apply_to_collection(
|
||||
input_dict,
|
||||
torch.Tensor,
|
||||
gather_all_tensors_if_available,
|
||||
group=self.process_group,
|
||||
)
|
||||
|
||||
for attr, reduction_fn in self._reductions.items():
|
||||
# pre-processing ops (stack or flatten for inputs)
|
||||
if isinstance(output_dict[attr][0], torch.Tensor):
|
||||
output_dict[attr] = torch.stack(output_dict[attr])
|
||||
elif isinstance(output_dict[attr][0], list):
|
||||
output_dict[attr] = _flatten(output_dict[attr])
|
||||
|
||||
assert isinstance(reduction_fn, (Callable, None))
|
||||
reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
|
||||
setattr(self, attr, reduced)
|
||||
|
||||
def _wrap_update(self, update):
|
||||
@functools.wraps(update)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
self._computed = None
|
||||
return update(*args, **kwargs)
|
||||
return wrapped_func
|
||||
|
||||
def _wrap_compute(self, compute):
|
||||
@functools.wraps(compute)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
# return cached value
|
||||
if self._computed is not None:
|
||||
return self._computed
|
||||
|
||||
if (
|
||||
self._to_sync
|
||||
and torch.distributed.is_available() # noqa: W503
|
||||
and torch.distributed.is_initialized() # noqa: W503
|
||||
):
|
||||
self._sync_dist()
|
||||
|
||||
self._computed = compute(*args, **kwargs)
|
||||
self.reset()
|
||||
|
||||
return self._computed
|
||||
|
||||
return wrapped_func
|
||||
|
||||
@abstractmethod
|
||||
def update(self) -> None: # pylint: disable=E0202
|
||||
"""
|
||||
Implement how outputs from forward should be cast
|
||||
|
||||
Args:
|
||||
data: input to forward method
|
||||
output: output from forward method
|
||||
|
||||
Returns:
|
||||
cast outputs
|
||||
Override this method to update the state variables of your metric class.
|
||||
"""
|
||||
return apply_to_collection(output, (torch.Tensor, np.ndarray), at_least_1d)
|
||||
pass
|
||||
|
||||
def ddp_sync(self, tensor: Any):
|
||||
@abstractmethod
|
||||
def compute(self): # pylint: disable=E0202
|
||||
"""
|
||||
Implement how the outputs from forward should be synced
|
||||
(per default just gathers all of them and adds them to self._step_vals)
|
||||
|
||||
Args:
|
||||
tensor: tensor to sync
|
||||
|
||||
Returns:
|
||||
synced output
|
||||
|
||||
Override this method to compute the final metric value from state variables
|
||||
synchronized across the distributed backend.
|
||||
"""
|
||||
gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group)
|
||||
return gathered_tensors
|
||||
|
||||
@staticmethod
|
||||
def ddp_reduce(self, data: Any, output: Any):
|
||||
"""
|
||||
Implement how the outputs from forward should be synced and reduced across nodes
|
||||
|
||||
Args:
|
||||
data: input to forward method
|
||||
output: output from the `output_convert` hook
|
||||
|
||||
Returns:
|
||||
synced output
|
||||
|
||||
"""
|
||||
synced = self.ddp_sync(output)
|
||||
agg_val = self.aggregate(synced)
|
||||
self._step_vals.append(agg_val)
|
||||
return agg_val
|
||||
|
||||
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Implement aggregation of values on the same device
|
||||
|
||||
Args:
|
||||
tensors: the values to be aggregated
|
||||
|
||||
Returns:
|
||||
aggregated values
|
||||
|
||||
"""
|
||||
# single tensor
|
||||
if len(tensors) == 1:
|
||||
tensors = tensors[0]
|
||||
if isinstance(tensors, Mapping):
|
||||
return {k: _stack_and_agg(tensors[k]) for k in tensors.keys()}
|
||||
if isinstance(tensors, list):
|
||||
return _stack_and_agg(tensors)
|
||||
if isinstance(tensors, tuple):
|
||||
return tensors
|
||||
if isinstance(tensors, torch.Tensor):
|
||||
return _stack_and_agg(tensors)
|
||||
|
||||
# multiple tensors (from aggregation over batches)
|
||||
if isinstance(tensors[0], Mapping):
|
||||
return {k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys()}
|
||||
if isinstance(tensors[0], Sequence):
|
||||
return tuple([torch.stack(tmp).sum(0) for tmp in zip(*tensors)])
|
||||
if isinstance(tensors[0], torch.Tensor):
|
||||
return torch.stack(tensors).sum(0)
|
||||
|
||||
raise TypeError("unknown metric value format to aggregate")
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
"""
|
||||
Implement additional metric computations to be done after the aggregation
|
||||
|
||||
Args:
|
||||
data: input to forward method
|
||||
output: output from the `aggregate` hook
|
||||
|
||||
Returns:
|
||||
final metric value
|
||||
|
||||
"""
|
||||
return output
|
||||
|
||||
@property
|
||||
def aggregated(self) -> torch.Tensor:
|
||||
aggr = self.aggregate(*self._step_vals if len(self._step_vals) > 1 else self._step_vals)
|
||||
self.reset()
|
||||
return self.compute(self, None, aggr)
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
self._step_vals = []
|
||||
|
||||
|
||||
def _stack_and_agg(tensors):
|
||||
""" Utility function for stacking and aggregating tensors """
|
||||
if isinstance(tensors, list):
|
||||
return torch.sum(torch.stack([t for t in tensors]), 0)
|
||||
return tensors.squeeze() if tensors.numel() == 1 else tensors
|
||||
|
||||
|
||||
class TensorMetric(Metric):
|
||||
"""
|
||||
Base class for metric implementation operating directly on tensors.
|
||||
All inputs and outputs will be cast to tensors if necessary.
|
||||
Already handles DDP sync and input/output conversions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def input_convert(self, data: Any):
|
||||
data = apply_to_collection(
|
||||
data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
|
||||
)
|
||||
return super(TensorMetric, self).input_convert(self, data)
|
||||
|
||||
@staticmethod
|
||||
def output_convert(self, data: Any, output: Any):
|
||||
|
||||
output = apply_to_collection(
|
||||
output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
|
||||
)
|
||||
return super(TensorMetric, self).output_convert(self, data, output)
|
||||
|
||||
|
||||
class NumpyMetric(Metric):
|
||||
"""
|
||||
Base class for metric implementation operating on numpy arrays.
|
||||
All inputs will be cast to numpy if necessary and all outputs will
|
||||
be cast to tensors if necessary.
|
||||
Already handles DDP sync and input/output conversions.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def input_convert(self, data: Any):
|
||||
data = apply_to_collection(data, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)
|
||||
return super(NumpyMetric, self).input_convert(self, data)
|
||||
|
||||
@staticmethod
|
||||
def output_convert(self, data: Any, output: Any):
|
||||
output = apply_to_collection(
|
||||
output, (torch.Tensor, np.ndarray, numbers.Number), convert_to_tensor, self.dtype, self.device
|
||||
)
|
||||
|
||||
return super(NumpyMetric, self).output_convert(self, data, output)
|
||||
"""
|
||||
This method automatically resets the metric state variables to their default value.
|
||||
"""
|
||||
for attr, default in self._defaults.items():
|
||||
current_val = getattr(self, attr)
|
||||
if isinstance(current_val, torch.Tensor):
|
||||
setattr(self, attr, deepcopy(default).to(current_val.device))
|
||||
else:
|
||||
setattr(self, attr, deepcopy(default))
|
||||
|
|
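Putting the new base-class contract together (``add_state`` in ``__init__``, ``update`` to accumulate, ``compute`` to finalize), a minimal custom metric might look like the sketch below; ``MeanError`` is an illustrative name, not one of the metrics shipped by this PR.

import torch
from pytorch_lightning.metrics import Metric

class MeanError(Metric):
    """Illustrative subclass of the new Metric base class."""

    def __init__(self, compute_on_step: bool = True):
        super().__init__(compute_on_step=compute_on_step)
        self.add_state("abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        self.abs_error += torch.sum(torch.abs(preds - target))
        self.total += target.numel()

    def compute(self):
        return self.abs_error / self.total

metric = MeanError()
print(metric(torch.tensor([0.0, 1.0, 2.0]), torch.tensor([0.0, 2.0, 2.0])))  # tensor(0.3333)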
|
@ -1,60 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class BLEUScore(Metric):
|
||||
"""
|
||||
Calculate BLEU score of machine translated text with one or more references.
|
||||
|
||||
Example:
|
||||
|
||||
>>> translate_corpus = ['the cat is on the mat'.split()]
|
||||
>>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
|
||||
>>> metric = BLEUScore()
|
||||
>>> metric(translate_corpus, reference_corpus)
|
||||
tensor(0.7598)
|
||||
"""
|
||||
|
||||
def __init__(self, n_gram: int = 4, smooth: bool = False):
|
||||
"""
|
||||
Args:
|
||||
n_gram: Gram value ranged from 1 to 4 (Default 4)
|
||||
smooth: Whether or not to apply smoothing – Lin et al. 2004
|
||||
"""
|
||||
super().__init__(name="bleu")
|
||||
self.n_gram = n_gram
|
||||
self.smooth = smooth
|
||||
|
||||
def forward(self, translate_corpus: list, reference_corpus: list) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
translate_corpus: An iterable of machine translated corpus
|
||||
reference_corpus: An iterable of iterables of reference corpus
|
||||
|
||||
Return:
|
||||
torch.Tensor: BLEU Score
|
||||
"""
|
||||
return bleu_score(
|
||||
translate_corpus=translate_corpus,
|
||||
reference_corpus=reference_corpus,
|
||||
n_gram=self.n_gram,
|
||||
smooth=self.smooth,
|
||||
).to(self.device, self.dtype)
|
|
@ -1,361 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Sequence, Any
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.regression import (
|
||||
mae,
|
||||
mse,
|
||||
psnr,
|
||||
rmse,
|
||||
rmsle,
|
||||
ssim
|
||||
)
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class MSE(Metric):
|
||||
"""
|
||||
Computes the mean squared loss.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0., 1, 2, 3])
|
||||
>>> target = torch.tensor([0., 1, 2, 2])
|
||||
>>> metric = MSE()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.2500)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduction: str = 'elementwise_mean',
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
"""
|
||||
super().__init__(name='mse')
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the mse loss.
|
||||
"""
|
||||
return mse(pred, target, return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
sse, n = output['squared_error'], output['n_observations']
|
||||
return sse / n
|
||||
|
||||
|
||||
class RMSE(Metric):
|
||||
"""
|
||||
Computes the root mean squared loss.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0., 1, 2, 3])
|
||||
>>> target = torch.tensor([0., 1, 2, 2])
|
||||
>>> metric = RMSE()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.5000)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduction: str = 'elementwise_mean',
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
"""
|
||||
super().__init__(name='rmse')
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the rmse loss.
|
||||
"""
|
||||
return rmse(pred, target, reduction='none', return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
""" Squaring needs to happend after ddp sync """
|
||||
sse, n = output['squared_error'], output['n_observations']
|
||||
return torch.sqrt(sse / n)
|
||||
|
||||
|
||||
class MAE(Metric):
|
||||
"""
|
||||
Computes the mean absolute loss or L1-loss.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0., 1, 2, 3])
|
||||
>>> target = torch.tensor([0., 1, 2, 2])
|
||||
>>> metric = MAE()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.2500)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduction: str = 'elementwise_mean',
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
"""
|
||||
super().__init__(name='mae')
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the mae loss.
|
||||
"""
|
||||
return mae(pred, target, return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
sae, n = output['absolute_error'], output['n_observations']
|
||||
return sae / n
|
||||
|
||||
|
||||
class RMSLE(Metric):
|
||||
"""
|
||||
Computes the root mean squared log loss.
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([0., 1, 2, 3])
|
||||
>>> target = torch.tensor([0., 1, 2, 2])
|
||||
>>> metric = RMSLE()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.1438)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reduction: str = 'elementwise_mean',
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
"""
|
||||
super().__init__(name='rmsle')
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with the rmsle loss.
|
||||
"""
|
||||
return mse(torch.log(pred + 1), torch.log(target + 1),
|
||||
self.reduction, return_state=True)
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
""" Squaring needs to happend after ddp sync """
|
||||
sse, n = output['squared_error'], output['n_observations']
|
||||
return torch.sqrt(sse / n)
|
||||
|
||||
|
||||
class PSNR(Metric):
|
||||
"""
|
||||
Computes the peak signal-to-noise ratio
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
|
||||
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
|
||||
>>> metric = PSNR()
|
||||
>>> metric(pred, target)
|
||||
tensor(2.5527)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_range: float = None,
|
||||
base: int = 10,
|
||||
reduction: str = 'elementwise_mean'
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
data_range: the range of the data. If None, it is determined from the data (max - min)
|
||||
base: a base of a logarithm to use (default: 10)
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
"""
|
||||
super().__init__(name='psnr')
|
||||
self.data_range = data_range
|
||||
self.base = float(base)
|
||||
self.reduction = reduction
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: predicted labels
|
||||
target: ground truth labels
|
||||
|
||||
Return:
|
||||
A Tensor with psnr score.
|
||||
"""
|
||||
return psnr(pred, target, self.data_range, self.base, self.reduction, return_state=True)
|
||||
|
||||
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
|
||||
""" Special aggregation function as the data range needs to be correctly synced """
|
||||
if len(tensors) == 1:
|
||||
tensors = tensors[0]
|
||||
output = {'data_range': torch.stack([t for t in tensors['data_range']]).max()}
|
||||
output.update({k: torch.stack([t for t in tensors[k]]).sum(0) for k in tensors.keys() if k != 'data_range'})
|
||||
return output
|
||||
|
||||
output = {'data_range': torch.stack([tensor['data_range'] for tensor in tensors]).max()}
|
||||
output.update({k: torch.stack([tensor[k] for tensor in tensors]).sum(0) for k in tensors[0].keys() if k != 'data_range'})
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
"""
|
||||
Compute final value based on the synced data_range, sum of squared errors
|
||||
and number of samples.
|
||||
|
||||
Args:
|
||||
data: input to forward method
|
||||
output: output from the `aggregate` hook
|
||||
|
||||
Returns:
|
||||
final metric value
|
||||
|
||||
"""
|
||||
sse, n, data_range = output['sum_squared_error'], output['n_obs'], output['data_range']
|
||||
psnr_base_e = 2 * torch.log(data_range) - torch.log(sse / n)
|
||||
psnr = psnr_base_e * (10 / torch.log(torch.tensor(self.base)))
|
||||
return psnr
|
||||
|
||||
|
||||
class SSIM(Metric):
|
||||
"""
|
||||
Computes Structural Similarity Index Measure
|
||||
|
||||
Example:
|
||||
|
||||
>>> pred = torch.rand([16, 1, 16, 16])
|
||||
>>> target = pred * 0.75
|
||||
>>> metric = SSIM()
|
||||
>>> metric(pred, target)
|
||||
tensor(0.9219)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel_size: Sequence[int] = (11, 11),
|
||||
sigma: Sequence[float] = (1.5, 1.5),
|
||||
reduction: str = "elementwise_mean",
|
||||
data_range: float = None,
|
||||
k1: float = 0.01,
|
||||
k2: float = 0.03
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
kernel_size: Size of the gaussian kernel (default: (11, 11))
|
||||
sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5))
|
||||
reduction: a method to reduce metric score over labels.
|
||||
|
||||
- ``'elementwise_mean'``: takes the mean (default)
|
||||
- ``'sum'``: takes the sum
|
||||
- ``'none'``: no reduction will be applied
|
||||
|
||||
data_range: Range of the image. If ``None``, it is determined from the image (max - min)
|
||||
k1: Parameter of SSIM. Default: 0.01
|
||||
k2: Parameter of SSIM. Default: 0.03
|
||||
"""
|
||||
super().__init__(name="ssim")
|
||||
self.kernel_size = kernel_size
|
||||
self.sigma = sigma
|
||||
self.reduction = reduction
|
||||
self.data_range = data_range
|
||||
self.k1 = k1
|
||||
self.k2 = k2
|
||||
|
||||
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
pred: Estimated image
|
||||
target: Ground truth image
|
||||
|
||||
Return:
|
||||
A Tensor with SSIM score.
|
||||
"""
|
||||
return ssim(pred, target, self.kernel_size, self.sigma, self.reduction, self.data_range, self.k1, self.k2)
|
|
@ -0,0 +1,3 @@
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError

@ -0,0 +1,54 @@
import torch
from typing import Any, Callable, Optional, Union

from pytorch_lightning.metrics.metric import Metric


class MeanAbsoluteError(Metric):
    """
    Computes mean absolute error.

    Example:

        >>> from pytorch_lightning.metrics import MeanAbsoluteError
        >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
        >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
        >>> mean_absolute_error = MeanAbsoluteError()
        >>> mean_absolute_error(preds, target)
        tensor(0.5000)
    """

    def __init__(
        self,
        compute_on_step: bool = True,
        ddp_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            ddp_sync_on_step=ddp_sync_on_step,
            process_group=process_group,
        )

        self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        assert preds.shape == target.shape
        abs_error = torch.abs(preds - target)

        self.sum_abs_error += torch.sum(abs_error)
        self.total += target.numel()

    def compute(self):
        """
        Computes mean absolute error over state.
        """
        return self.sum_abs_error / self.total

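For illustration only (not part of this diff): a minimal sketch of how the new stateful API accumulates `update` calls across several batches before `compute`, assuming only the `MeanAbsoluteError` class added above.

    >>> import torch
    >>> from pytorch_lightning.metrics import MeanAbsoluteError
    >>> mae = MeanAbsoluteError()
    >>> mae.update(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 1.0]))  # batch 1: absolute errors 0 and 1
    >>> mae.update(torch.tensor([3.0]), torch.tensor([5.0]))            # batch 2: absolute error 2
    >>> mae.compute()  # (0 + 1 + 2) / 3
    tensor(1.)
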
@ -0,0 +1,55 @@
import torch
from typing import Any, Callable, Optional, Union

from pytorch_lightning.metrics.metric import Metric


class MeanSquaredError(Metric):
    """
    Computes mean squared error.

    Example:

        >>> from pytorch_lightning.metrics import MeanSquaredError
        >>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
        >>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
        >>> mean_squared_error = MeanSquaredError()
        >>> mean_squared_error(preds, target)
        tensor(0.8750)

    """

    def __init__(
        self,
        compute_on_step: bool = True,
        ddp_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            ddp_sync_on_step=ddp_sync_on_step,
            process_group=process_group,
        )

        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        assert preds.shape == target.shape
        squared_error = torch.pow(preds - target, 2)

        self.sum_squared_error += torch.sum(squared_error)
        self.total += target.numel()

    def compute(self):
        """
        Computes mean squared error over state.
        """
        return self.sum_squared_error / self.total

@ -0,0 +1,55 @@
import torch
from typing import Any, Callable, Optional, Union

from pytorch_lightning.metrics.metric import Metric


class MeanSquaredLogError(Metric):
    """
    Computes mean squared logarithmic error.

    Example:

        >>> from pytorch_lightning.metrics import MeanSquaredLogError
        >>> target = torch.tensor([2.5, 5, 4, 8])
        >>> preds = torch.tensor([3, 5, 2.5, 7])
        >>> mean_squared_log_error = MeanSquaredLogError()
        >>> mean_squared_log_error(preds, target)
        tensor(0.0397)

    """

    def __init__(
        self,
        compute_on_step: bool = True,
        ddp_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            ddp_sync_on_step=ddp_sync_on_step,
            process_group=process_group,
        )

        self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        assert preds.shape == target.shape
        squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2)

        self.sum_squared_log_error += torch.sum(squared_log_error)
        self.total += target.numel()

    def compute(self):
        """
        Compute mean squared logarithmic error over state.
        """
        return self.sum_squared_log_error / self.total

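For reference (not part of this diff), the doctest value follows directly from the definition above: applying log1p to preds and target gives squared differences of roughly 0.0178, 0.0, 0.1272 and 0.0139, whose mean is about 0.0397, matching tensor(0.0397).
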
@ -1,85 +0,0 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity
|
||||
from pytorch_lightning.metrics.metric import TensorMetric
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
|
||||
|
||||
class EmbeddingSimilarity(TensorMetric):
|
||||
"""
|
||||
Computes similarity between embeddings
|
||||
|
||||
Example:
|
||||
>>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]])
|
||||
>>> embedding_similarity(embeddings)
|
||||
tensor([[0.0000, 1.0000, 0.9759],
|
||||
[1.0000, 0.0000, 0.9759],
|
||||
[0.9759, 0.9759, 0.0000]])
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
similarity: str = 'cosine',
|
||||
zero_diagonal: bool = True,
|
||||
reduction: str = 'mean',
|
||||
reduce_group: Any = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
similarity: 'dot' or 'cosine'
|
||||
reduction: 'none', 'sum', 'mean' (all along dim -1)
|
||||
zero_diagonal: if True, the diagonals are set to zero
|
||||
reduce_group: the process group to reduce metric results from DDP
|
||||
|
||||
"""
|
||||
super().__init__(name='embedding_similarity',
|
||||
reduce_group=reduce_group)
|
||||
assert similarity in ('dot', 'cosine')
|
||||
self.similarity = similarity
|
||||
assert isinstance(zero_diagonal, bool)
|
||||
self.zero_diagonal = zero_diagonal
|
||||
assert reduction in ('none', 'sum', 'mean')
|
||||
self.reduction = reduction
|
||||
|
||||
rank_zero_warn('Please note that Metric `EmbeddingSimilarity` does not support aggregation.')
|
||||
|
||||
def forward(self, batch: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Actual metric computation
|
||||
|
||||
Args:
|
||||
batch: tensor containing embeddings with shape (batch_size, dim)
|
||||
|
||||
Return:
|
||||
A square matrix (batch, batch) with the similarity scores between all elements
|
||||
If sum or mean are used, then returns (b, 1) with the reduced value for each row
|
||||
"""
|
||||
return embedding_similarity(batch,
|
||||
similarity=self.similarity,
|
||||
zero_diagonal=self.zero_diagonal,
|
||||
reduction=self.reduction)
|
||||
|
||||
@staticmethod
|
||||
def ddp_reduce(self, data: Any, output: Any):
|
||||
""" reduction for this metric does not make sense """
|
||||
return output
|
||||
|
||||
@property
|
||||
def aggregated(self):
|
||||
raise ValueError('Metric `EmbeddingSimilarity` does not support aggregation.')
|
File diff suppressed because it is too large
@ -0,0 +1,19 @@
import torch

from typing import Any, Callable, Optional, Union


def dim_zero_cat(x):
    return torch.cat(x, dim=0)


def dim_zero_sum(x):
    return torch.sum(x, dim=0)


def dim_zero_mean(x):
    return torch.mean(x, dim=0)


def _flatten(x):
    return [item for sublist in x for item in sublist]

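For illustration only (not part of this diff): these helpers presumably reduce synced metric state along the batch dimension; a minimal sketch using the functions defined above.

    >>> import torch
    >>> states = [torch.tensor([1., 2.]), torch.tensor([3.])]
    >>> dim_zero_cat(states)                                   # concatenate list-type state
    tensor([1., 2., 3.])
    >>> dim_zero_sum(torch.stack([torch.tensor([1., 2.]), torch.tensor([3., 4.])]))  # reduce gathered state
    tensor([4., 6.])
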
@ -16,7 +16,15 @@ import os
import warnings
from functools import wraps

import torch
from pytorch_lightning import _logger as log
from typing import Union, Optional, Any

if torch.distributed.is_available():
    from torch.distributed import ReduceOp
else:
    class ReduceOp:
        SUM = None


def rank_zero_only(fn):

@ -63,3 +71,71 @@ def find_free_network_port() -> int:
    port = s.getsockname()[1]
    s.close()
    return port


def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
    """
    Function to gather all tensors from several ddp processes onto a list that
    is broadcasted to all processes

    Args:
        result: the value to sync
        group: the process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: list with size equal to the process group where
            gathered_result[i] corresponds to result tensor from process i

    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if group is None:
            group = torch.distributed.group.WORLD

        world_size = torch.distributed.get_world_size(group)

        gathered_result = [torch.zeros_like(result) for _ in range(world_size)]

        # sync and broadcast all
        torch.distributed.barrier(group=group)
        torch.distributed.all_gather(gathered_result, result, group)

        result = gathered_result
    return result


def sync_ddp_if_available(
    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """
    Function to reduce the tensors from several ddp processes to one master process

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be a string of 'avg', 'mean' to calculate the mean during reduction.

    Return:
        reduced value
    """

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        divide_by_world_size = False

        if group is None:
            group = torch.distributed.group.WORLD

        if reduce_op is None:
            reduce_op = torch.distributed.ReduceOp.SUM
        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
            reduce_op = torch.distributed.ReduceOp.SUM
            divide_by_world_size = True

        # sync all processes before reduction
        torch.distributed.barrier(group=group)
        torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)

        if divide_by_world_size:
            result = result / torch.distributed.get_world_size(group)

    return result

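For illustration only (not part of this diff): a minimal sketch of `sync_ddp_if_available`. Outside an initialized process group it returns the input unchanged; with `reduce_op='mean'` it all-reduces with SUM and divides by the world size. The import path `pytorch_lightning.utilities.distributed` is an assumption based on the functions shown in this hunk.

    >>> import torch
    >>> from pytorch_lightning.utilities.distributed import sync_ddp_if_available
    >>> t = torch.tensor([2., 4.])
    >>> sync_ddp_if_available(t, reduce_op='mean')  # single process: returned as-is
    tensor([2., 4.])
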
@ -0,0 +1,5 @@
import os

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.test_metric import Dummy

@ -0,0 +1,155 @@
|
|||
import os
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
|
||||
from pytorch_lightning.metrics.classification.accuracy import Accuracy
|
||||
from sklearn.metrics import accuracy_score
|
||||
|
||||
from tests.metrics.utils import compute_batch, setup_ddp
|
||||
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
# global vars
|
||||
num_classes = 5
|
||||
threshold = 0.5
|
||||
extra_dim = 3
|
||||
|
||||
Input = namedtuple('Input', ["preds", "target"])
|
||||
|
||||
|
||||
def test_accuracy_invalid_shape():
|
||||
with pytest.raises(ValueError):
|
||||
acc = Accuracy()
|
||||
acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3))
|
||||
|
||||
|
||||
_binary_prob_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
|
||||
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))
|
||||
)
|
||||
|
||||
|
||||
def _binary_prob_sk_metric(preds, target):
|
||||
sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8)
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
_binary_inputs = Input(
|
||||
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)),
|
||||
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,))
|
||||
)
|
||||
|
||||
|
||||
def _binary_sk_metric(preds, target):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
_multilabel_prob_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes),
|
||||
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes))
|
||||
)
|
||||
|
||||
|
||||
def _multilabel_prob_sk_metric(preds, target):
|
||||
sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8)
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
_multilabel_inputs = Input(
|
||||
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)),
|
||||
target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes))
|
||||
)
|
||||
|
||||
|
||||
def _multilabel_sk_metric(preds, target):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
_multiclass_prob_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes),
|
||||
target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE))
|
||||
)
|
||||
|
||||
|
||||
def _multiclass_prob_sk_metric(preds, target):
|
||||
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
_multiclass_inputs = Input(
|
||||
preds=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)),
|
||||
target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE))
|
||||
)
|
||||
|
||||
|
||||
def _multiclass_sk_metric(preds, target):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
_multidim_multiclass_prob_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, extra_dim),
|
||||
target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE, extra_dim))
|
||||
)
|
||||
|
||||
|
||||
def _multidim_multiclass_prob_sk_metric(preds, target):
|
||||
sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
_multidim_multiclass_inputs = Input(
|
||||
preds=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)),
|
||||
target=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE))
|
||||
)
|
||||
|
||||
|
||||
def _multidim_multiclass_sk_metric(preds, target):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
|
||||
return accuracy_score(y_true=sk_target, y_pred=sk_preds)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ddp", [True, False])
|
||||
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
|
||||
@pytest.mark.parametrize("preds, target, sk_metric", [
|
||||
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric),
|
||||
(_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric),
|
||||
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric),
|
||||
(_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric),
|
||||
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric),
|
||||
(_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric),
|
||||
(
|
||||
_multidim_multiclass_prob_inputs.preds,
|
||||
_multidim_multiclass_prob_inputs.target,
|
||||
_multidim_multiclass_prob_sk_metric
|
||||
),
|
||||
(
|
||||
_multidim_multiclass_inputs.preds,
|
||||
_multidim_multiclass_inputs.target,
|
||||
_multidim_multiclass_sk_metric
|
||||
)
|
||||
])
|
||||
def test_accuracy(ddp, ddp_sync_on_step, preds, target, sk_metric):
|
||||
compute_batch(preds, target, Accuracy, sk_metric, ddp_sync_on_step, ddp, metric_args={"threshold": threshold})
|
|
@ -1,485 +0,0 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import (
|
||||
accuracy_score as sk_accuracy,
|
||||
jaccard_score as sk_jaccard_score,
|
||||
precision_score as sk_precision,
|
||||
recall_score as sk_recall,
|
||||
f1_score as sk_f1_score,
|
||||
fbeta_score as sk_fbeta_score,
|
||||
confusion_matrix as sk_confusion_matrix,
|
||||
roc_curve as sk_roc_curve,
|
||||
roc_auc_score as sk_roc_auc_score,
|
||||
precision_recall_curve as sk_precision_recall_curve
|
||||
)
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.metrics.functional.classification import (
|
||||
to_onehot,
|
||||
to_categorical,
|
||||
get_num_classes,
|
||||
stat_scores,
|
||||
stat_scores_multiple_classes,
|
||||
accuracy,
|
||||
confusion_matrix,
|
||||
precision,
|
||||
recall,
|
||||
fbeta_score,
|
||||
f1_score,
|
||||
_binary_clf_curve,
|
||||
dice_score,
|
||||
average_precision,
|
||||
auroc,
|
||||
precision_recall_curve,
|
||||
roc,
|
||||
auc,
|
||||
iou,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [
|
||||
pytest.param(sk_accuracy, accuracy, False, id='accuracy'),
|
||||
pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'),
|
||||
pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'),
|
||||
pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'),
|
||||
pytest.param(partial(sk_f1_score, average='micro'), f1_score, False, id='f1_score'),
|
||||
pytest.param(partial(sk_fbeta_score, average='micro', beta=2),
|
||||
partial(fbeta_score, beta=2), False, id='fbeta_score'),
|
||||
pytest.param(sk_confusion_matrix, confusion_matrix, False, id='confusion_matrix'),
|
||||
pytest.param(sk_roc_curve, roc, True, id='roc'),
|
||||
pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'),
|
||||
pytest.param(sk_roc_auc_score, auroc, True, id='auroc')
|
||||
])
|
||||
def test_against_sklearn(sklearn_metric, torch_metric, only_binary):
|
||||
"""Compare PL metrics to sklearn version. """
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
# for metrics with only_binary=False, we try out different combinations of number
|
||||
# of labels in pred and target (also test binary)
|
||||
# for metrics with only_binary=True, target is always binary and pred will be
|
||||
# (unnormalized) class probabilities
|
||||
class_comb = [(5, 2)] if only_binary else [(10, 10), (5, 10), (10, 5), (2, 2)]
|
||||
for n_cls_pred, n_cls_target in class_comb:
|
||||
pred = torch.randint(n_cls_pred, (300,), device=device)
|
||||
target = torch.randint(n_cls_target, (300,), device=device)
|
||||
|
||||
sk_score = sklearn_metric(target.cpu().detach().numpy(),
|
||||
pred.cpu().detach().numpy())
|
||||
pl_score = torch_metric(pred, target)
|
||||
|
||||
# if multi output
|
||||
if isinstance(sk_score, tuple):
|
||||
sk_score = [torch.tensor(sk_s.copy(), dtype=torch.float, device=device) for sk_s in sk_score]
|
||||
for sk_s, pl_s in zip(sk_score, pl_score):
|
||||
assert torch.allclose(sk_s, pl_s.float())
|
||||
else:
|
||||
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
|
||||
assert torch.allclose(sk_score, pl_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted'])
|
||||
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
|
||||
pytest.param(sk_precision, precision, id='precision'),
|
||||
pytest.param(sk_recall, recall, id='recall'),
|
||||
pytest.param(sk_f1_score, f1_score, id='f1_score'),
|
||||
pytest.param(partial(sk_fbeta_score, beta=2), partial(fbeta_score, beta=2), id='fbeta_score')
|
||||
])
|
||||
def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric):
|
||||
""" Test metrics where the class_reduction parameter have a correponding
|
||||
value in sklearn """
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
pred = torch.randint(10, (300,), device=device)
|
||||
target = torch.randint(10, (300,), device=device)
|
||||
sk_score = sklearn_metric(target.cpu().detach().numpy(),
|
||||
pred.cpu().detach().numpy(),
|
||||
average=class_reduction)
|
||||
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
|
||||
pl_score = torch_metric(pred, target, class_reduction=class_reduction)
|
||||
assert torch.allclose(sk_score, pl_score)
|
||||
|
||||
|
||||
def test_onehot():
|
||||
test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
expected = torch.stack([
|
||||
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
|
||||
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
|
||||
])
|
||||
|
||||
assert test_tensor.shape == (2, 5)
|
||||
assert expected.shape == (2, 10, 5)
|
||||
|
||||
onehot_classes = to_onehot(test_tensor, num_classes=10)
|
||||
onehot_no_classes = to_onehot(test_tensor)
|
||||
|
||||
assert torch.allclose(onehot_classes, onehot_no_classes)
|
||||
|
||||
assert onehot_classes.shape == expected.shape
|
||||
assert onehot_no_classes.shape == expected.shape
|
||||
|
||||
assert torch.allclose(expected.to(onehot_no_classes), onehot_no_classes)
|
||||
assert torch.allclose(expected.to(onehot_classes), onehot_classes)
|
||||
|
||||
|
||||
def test_to_categorical():
|
||||
test_tensor = torch.stack([
|
||||
torch.cat([torch.eye(5, dtype=int), torch.zeros((5, 5), dtype=int)]),
|
||||
torch.cat([torch.zeros((5, 5), dtype=int), torch.eye(5, dtype=int)])
|
||||
]).to(torch.float)
|
||||
|
||||
expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
|
||||
assert expected.shape == (2, 5)
|
||||
assert test_tensor.shape == (2, 10, 5)
|
||||
|
||||
result = to_categorical(test_tensor)
|
||||
|
||||
assert result.shape == expected.shape
|
||||
assert torch.allclose(result, expected.to(result.dtype))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'num_classes', 'expected_num_classes'], [
|
||||
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), 10, 10),
|
||||
pytest.param(torch.rand(32, 10, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
|
||||
pytest.param(torch.rand(32, 28, 28), torch.randint(10, (32, 28, 28)), None, 10),
|
||||
])
|
||||
def test_get_num_classes(pred, target, num_classes, expected_num_classes):
|
||||
assert get_num_classes(pred, target, num_classes) == expected_num_classes
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp',
|
||||
'expected_tn', 'expected_fn', 'expected_support'], [
|
||||
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2),
|
||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 1, 1, 1, 1, 2)
|
||||
])
|
||||
def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
|
||||
tp, fp, tn, fn, sup = stat_scores(pred, target, class_index=4)
|
||||
|
||||
assert tp.item() == expected_tp
|
||||
assert fp.item() == expected_fp
|
||||
assert tn.item() == expected_tn
|
||||
assert fn.item() == expected_fn
|
||||
assert sup.item() == expected_support
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp',
|
||||
'expected_tn', 'expected_fn', 'expected_support'], [
|
||||
pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none',
|
||||
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
|
||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none',
|
||||
[1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]),
|
||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum',
|
||||
torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)),
|
||||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
|
||||
torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
|
||||
])
|
||||
def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
|
||||
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)
|
||||
|
||||
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
|
||||
assert torch.allclose(torch.tensor(expected_fp).to(fp), fp)
|
||||
assert torch.allclose(torch.tensor(expected_tn).to(tn), tn)
|
||||
assert torch.allclose(torch.tensor(expected_fn).to(fn), fn)
|
||||
assert torch.allclose(torch.tensor(expected_support).to(sup), sup)
|
||||
|
||||
|
||||
def test_multilabel_accuracy():
|
||||
# Dense label indicator matrix format
|
||||
y1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
|
||||
y2 = torch.tensor([[0, 0, 1], [1, 0, 1]])
|
||||
|
||||
assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.]))
|
||||
assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.]))
|
||||
assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.]))
|
||||
assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.]))
|
||||
assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.]))
|
||||
|
||||
# when num_classes does not match the number of classes extracted from the input, we expect a warning
|
||||
with pytest.warns(RuntimeWarning,
|
||||
match=r'You have set .* number of classes which is'
|
||||
r' different from predicted (.*) and'
|
||||
r' target (.*) number of classes'):
|
||||
_ = accuracy(y2, torch.zeros_like(y2), num_classes=3)
|
||||
|
||||
|
||||
def test_accuracy():
|
||||
pred = torch.tensor([0, 1, 2, 3])
|
||||
target = torch.tensor([0, 1, 2, 2])
|
||||
acc = accuracy(pred, target)
|
||||
|
||||
assert acc.item() == 0.75
|
||||
|
||||
pred = torch.tensor([0, 1, 2, 2])
|
||||
target = torch.tensor([0, 1, 1, 3])
|
||||
acc = accuracy(pred, target)
|
||||
|
||||
assert acc.item() == 0.50
|
||||
|
||||
|
||||
def test_confusion_matrix():
|
||||
target = (torch.arange(120) % 3).view(-1, 1)
|
||||
pred = target.clone()
|
||||
cm = confusion_matrix(pred, target, normalize=True)
|
||||
|
||||
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))
|
||||
|
||||
pred = torch.zeros_like(pred)
|
||||
cm = confusion_matrix(pred, target, normalize=True)
|
||||
assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))
|
||||
|
||||
target = torch.LongTensor([0, 0, 0, 0, 0])
|
||||
pred = target.clone()
|
||||
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
|
||||
assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]]))
|
||||
|
||||
# Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
|
||||
target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9)
|
||||
pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15)
|
||||
cm = confusion_matrix(pred, target, normalize=False, num_classes=3)
|
||||
assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]]))
|
||||
to_compare = cm / torch.tensor([[13.], [16.], [9.]])
|
||||
|
||||
cm = confusion_matrix(pred, target, normalize=True, num_classes=3)
|
||||
assert torch.allclose(cm, to_compare)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [
|
||||
pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]),
|
||||
pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5])
|
||||
])
|
||||
def test_precision_recall(pred, target, expected_prec, expected_rec):
|
||||
prec = precision(pred, target, class_reduction='none')
|
||||
rec = recall(pred, target, class_reduction='none')
|
||||
|
||||
assert torch.allclose(torch.tensor(expected_prec).to(prec), prec)
|
||||
assert torch.allclose(torch.tensor(expected_rec).to(rec), rec)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'beta', 'exp_score'], [
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 0.5, [0.5, 0.5]),
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 1, [0.5, 0.5]),
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]),
|
||||
])
|
||||
def test_fbeta_score(pred, target, beta, exp_score):
|
||||
score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, class_reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, class_reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'exp_score'], [
|
||||
pytest.param([0., 0., 0., 0.], [1., 1., 1., 1.], [0.0, 0.0]),
|
||||
pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], [0.5, 0.5]),
|
||||
pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]),
|
||||
])
|
||||
def test_f1_score(pred, target, exp_score):
|
||||
score = f1_score(torch.tensor(pred), torch.tensor(target), class_reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), class_reduction='none')
|
||||
assert torch.allclose(score, torch.tensor(exp_score))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['sample_weight', 'pos_label', "exp_shape"], [
|
||||
pytest.param(1, 1., 42),
|
||||
pytest.param(None, 1., 42),
|
||||
])
|
||||
def test_binary_clf_curve(sample_weight, pos_label, exp_shape):
|
||||
# TODO: move back the pred and target to test func arguments
|
||||
# if you fix the array inside the function, you'd also have to fix the shape,
|
||||
# because when the array changes, you also have to fix the shape
|
||||
seed_everything(0)
|
||||
pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100
|
||||
target = torch.tensor([0, 1] * 50, dtype=torch.int)
|
||||
if sample_weight is not None:
|
||||
sample_weight = torch.ones_like(pred) * sample_weight
|
||||
|
||||
fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
|
||||
|
||||
assert isinstance(tps, torch.Tensor)
|
||||
assert isinstance(fps, torch.Tensor)
|
||||
assert isinstance(thresh, torch.Tensor)
|
||||
assert tps.shape == (exp_shape,)
|
||||
assert fps.shape == (exp_shape,)
|
||||
assert thresh.shape == (exp_shape,)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
|
||||
pytest.param([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1., 1.], [1, 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4])
|
||||
])
|
||||
def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
|
||||
p, r, t = precision_recall_curve(torch.tensor(pred), torch.tensor(target))
|
||||
assert p.size() == r.size()
|
||||
assert p.size(0) == t.size(0) + 1
|
||||
|
||||
assert torch.allclose(p, torch.tensor(expected_p).to(p))
|
||||
assert torch.allclose(r, torch.tensor(expected_r).to(r))
|
||||
assert torch.allclose(t, torch.tensor(expected_t).to(t))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
|
||||
pytest.param([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]),
|
||||
pytest.param([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]),
|
||||
pytest.param([1, 1], [1, 0], [0, 1], [0, 1]),
|
||||
pytest.param([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]),
|
||||
pytest.param([0.5, 0.5], [0, 1], [0, 1], [0, 1]),
|
||||
])
|
||||
def test_roc_curve(pred, target, expected_tpr, expected_fpr):
|
||||
fpr, tpr, thresh = roc(torch.tensor(pred), torch.tensor(target))
|
||||
|
||||
assert fpr.shape == tpr.shape
|
||||
assert fpr.size(0) == thresh.size(0)
|
||||
assert torch.allclose(fpr, torch.tensor(expected_fpr).to(fpr))
|
||||
assert torch.allclose(tpr, torch.tensor(expected_tpr).to(tpr))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([0, 1, 0, 1], [0, 1, 0, 1], 1.),
|
||||
pytest.param([1, 1, 0, 0], [0, 0, 1, 1], 0.),
|
||||
pytest.param([1, 1, 1, 1], [1, 1, 0, 0], 0.5),
|
||||
pytest.param([1, 1, 0, 0], [1, 1, 0, 0], 1.),
|
||||
pytest.param([0.5, 0.5, 0.5, 0.5], [1, 1, 0, 0], 0.5),
|
||||
])
|
||||
def test_auroc(pred, target, expected):
|
||||
score = auroc(torch.tensor(pred), torch.tensor(target)).item()
|
||||
assert score == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['x', 'y', 'expected'], [
|
||||
pytest.param([0, 1], [0, 1], 0.5),
|
||||
pytest.param([1, 0], [0, 1], 0.5),
|
||||
pytest.param([1, 0, 0], [0, 1, 1], 0.5),
|
||||
pytest.param([0, 1], [1, 1], 1),
|
||||
pytest.param([0, 0.5, 1], [0, 0.5, 1], 0.5),
|
||||
])
|
||||
def test_auc(x, y, expected):
|
||||
# Test Area Under Curve (AUC) computation
|
||||
assert auc(torch.tensor(x), torch.tensor(y)) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['scores', 'target', 'expected_score'], [
|
||||
# Check the average_precision_score of a constant predictor is
|
||||
# the TPR
|
||||
# Generate a dataset with 25% of positives
|
||||
# And a constant score
|
||||
# The precision is then the fraction of positive whatever the recall
|
||||
# is, as there is only one threshold:
|
||||
pytest.param(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 0, 0, 1]), .25),
|
||||
# With threshold 0.8 : 1 TP and 2 TN and one FN
|
||||
pytest.param(torch.tensor([.6, .7, .8, 9]), torch.tensor([1, 0, 0, 1]), .75),
|
||||
])
|
||||
def test_average_precision(scores, target, expected_score):
|
||||
assert average_precision(scores, target) == expected_score
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([[0, 0], [1, 1]], [[0, 0], [1, 1]], 1.),
|
||||
pytest.param([[1, 1], [0, 0]], [[0, 0], [1, 1]], 0.),
|
||||
pytest.param([[1, 1], [1, 1]], [[1, 1], [0, 0]], 2 / 3),
|
||||
pytest.param([[1, 1], [0, 0]], [[1, 1], [0, 0]], 1.),
|
||||
])
|
||||
def test_dice_score(pred, target, expected):
|
||||
score = dice_score(torch.tensor(pred), torch.tensor(target))
|
||||
assert score == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
|
||||
pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
|
||||
pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
|
||||
pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
|
||||
pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
|
||||
pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
|
||||
pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
|
||||
])
|
||||
def test_iou(half_ones, reduction, ignore_index, expected):
|
||||
pred = (torch.arange(120) % 3).view(-1, 1)
|
||||
target = (torch.arange(120) % 3).view(-1, 1)
|
||||
if half_ones:
|
||||
pred[:60] = 1
|
||||
iou_val = iou(
|
||||
pred=pred,
|
||||
target=target,
|
||||
ignore_index=ignore_index,
|
||||
reduction=reduction,
|
||||
)
|
||||
assert torch.allclose(iou_val, expected, atol=1e-9)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('metric', [auroc])
|
||||
def test_error_on_multiclass_input(metric):
|
||||
""" check that these metrics raise an error if they are used for multiclass problems """
|
||||
pred = torch.randint(0, 10, (100, ))
|
||||
target = torch.randint(0, 10, (100, ))
|
||||
with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"):
|
||||
_ = metric(pred, target)
|
||||
|
||||
|
||||
# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
|
||||
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
|
||||
# `absent_score`.
|
||||
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
|
||||
# Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
|
||||
# scores the function can return ([0., 1.] range, inclusive).
|
||||
# 2 classes, class 0 is correct everywhere, class 1 is absent.
|
||||
pytest.param([0], [0], None, -1., 2, [1., -1.]),
|
||||
pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
|
||||
# absent_score not applied if only class 0 is present and it's the only class.
|
||||
pytest.param([0], [0], None, -1., 1, [1.]),
|
||||
# 2 classes, class 1 is correct everywhere, class 0 is absent.
|
||||
pytest.param([1], [1], None, -1., 2, [-1., 1.]),
|
||||
pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
|
||||
# When 0 index ignored, class 0 does not get a score (not even the absent_score).
|
||||
pytest.param([1], [1], 0, -1., 2, [1.0]),
|
||||
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
|
||||
pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
|
||||
pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
|
||||
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
|
||||
pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
|
||||
pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
|
||||
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
|
||||
# 2 is absent.
|
||||
pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
|
||||
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
|
||||
# 2 is absent.
|
||||
pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
|
||||
# Sanity checks with absent_score of 1.0.
|
||||
pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
|
||||
pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
|
||||
])
|
||||
def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
|
||||
iou_val = iou(
|
||||
pred=torch.tensor(pred),
|
||||
target=torch.tensor(target),
|
||||
ignore_index=ignore_index,
|
||||
absent_score=absent_score,
|
||||
num_classes=num_classes,
|
||||
reduction='none',
|
||||
)
|
||||
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
|
||||
|
||||
|
||||
# example data taken from
|
||||
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
|
||||
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
|
||||
# Ignoring an index outside of [0, num_classes-1] should have no effect.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
|
||||
# Ignoring a valid index drops only that index from the result.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
|
||||
# When reducing to mean or sum, the ignored index does not contribute to the output.
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
|
||||
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
|
||||
])
|
||||
def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
|
||||
iou_val = iou(
|
||||
pred=torch.tensor(pred),
|
||||
target=torch.tensor(target),
|
||||
ignore_index=ignore_index,
|
||||
num_classes=num_classes,
|
||||
reduction=reduction,
|
||||
)
|
||||
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
|
|
@ -1,66 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu
|
||||
|
||||
from pytorch_lightning.metrics.functional.nlp import bleu_score
|
||||
|
||||
# example taken from
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.sentence_bleu
|
||||
HYPOTHESIS1 = tuple(
|
||||
"It is a guide to action which ensures that the military always obeys the commands of the party".split()
|
||||
)
|
||||
REFERENCE1 = tuple("It is a guide to action that ensures that the military will forever heed Party commands".split())
|
||||
REFERENCE2 = tuple(
|
||||
"It is a guiding principle which makes the military forces always being under the command of the Party".split()
|
||||
)
|
||||
REFERENCE3 = tuple("It is the practical guide for the army always to heed the directions of the party".split())
|
||||
|
||||
|
||||
# example taken from
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
|
||||
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
|
||||
HYP2 = "he read the book because he was interested in world history".split()
|
||||
|
||||
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
|
||||
REF1B = "It is a guiding principle which makes the military force always being under the command of the Party".split()
|
||||
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
|
||||
REF2A = "he was interested in world history because he read the book".split()
|
||||
|
||||
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
|
||||
HYPOTHESES = [HYP1, HYP2]
|
||||
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.SmoothingFunction
|
||||
smooth_func = SmoothingFunction().method2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["weights", "n_gram", "smooth_func", "smooth"],
|
||||
[
|
||||
pytest.param([1], 1, None, False),
|
||||
pytest.param([0.5, 0.5], 2, smooth_func, True),
|
||||
pytest.param([0.333333, 0.333333, 0.333333], 3, None, False),
|
||||
pytest.param([0.25, 0.25, 0.25, 0.25], 4, smooth_func, True),
|
||||
],
|
||||
)
|
||||
def test_bleu_score(weights, n_gram, smooth_func, smooth):
|
||||
nltk_output = sentence_bleu(
|
||||
[REFERENCE1, REFERENCE2, REFERENCE3], HYPOTHESIS1, weights=weights, smoothing_function=smooth_func
|
||||
)
|
||||
pl_output = bleu_score([HYPOTHESIS1], [[REFERENCE1, REFERENCE2, REFERENCE3]], n_gram=n_gram, smooth=smooth)
|
||||
assert torch.allclose(pl_output, torch.tensor(nltk_output))
|
||||
|
||||
nltk_output = corpus_bleu(LIST_OF_REFERENCES, HYPOTHESES, weights=weights, smoothing_function=smooth_func)
|
||||
pl_output = bleu_score(HYPOTHESES, LIST_OF_REFERENCES, n_gram=n_gram, smooth=smooth)
|
||||
assert torch.allclose(pl_output, torch.tensor(nltk_output))
|
||||
|
||||
|
||||
def test_bleu_empty():
|
||||
hyp = [[]]
|
||||
ref = [[[]]]
|
||||
assert bleu_score(hyp, ref) == torch.tensor(0.0)
|
||||
|
||||
|
||||
def test_no_4_gram():
|
||||
hyps = [["My", "full", "pytorch-lightning"]]
|
||||
refs = [[["My", "full", "pytorch-lightning", "test"], ["Completely", "Different"]]]
|
||||
assert bleu_score(hyps, refs) == torch.tensor(0.0)
|
|
@ -1,30 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce
|
||||
|
||||
|
||||
def test_reduce():
|
||||
start_tensor = torch.rand(50, 40, 30)
|
||||
|
||||
assert torch.allclose(reduce(start_tensor, 'elementwise_mean'), torch.mean(start_tensor))
|
||||
assert torch.allclose(reduce(start_tensor, 'sum'), torch.sum(start_tensor))
|
||||
assert torch.allclose(reduce(start_tensor, 'none'), start_tensor)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
reduce(start_tensor, 'error_reduction')
|
||||
|
||||
|
||||
def test_class_reduce():
|
||||
num = torch.randint(1, 10, (100,)).float()
|
||||
denom = torch.randint(10, 20, (100,)).float()
|
||||
weights = torch.randint(1, 100, (100,)).float()
|
||||
|
||||
assert torch.allclose(class_reduce(num, denom, weights, 'micro'),
|
||||
torch.sum(num) / torch.sum(denom))
|
||||
assert torch.allclose(class_reduce(num, denom, weights, 'macro'),
|
||||
torch.mean(num / denom))
|
||||
assert torch.allclose(class_reduce(num, denom, weights, 'weighted'),
|
||||
torch.sum(num / denom * (weights / torch.sum(weights))))
|
||||
assert torch.allclose(class_reduce(num, denom, weights, 'none'),
|
||||
num / denom)
|
|
@ -1,175 +0,0 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from functools import partial
|
||||
from math import sqrt
|
||||
from skimage.metrics import (
|
||||
peak_signal_noise_ratio as ski_psnr,
|
||||
structural_similarity as ski_ssim
|
||||
)
|
||||
from sklearn.metrics import (
|
||||
mean_absolute_error as mae_sk,
|
||||
mean_squared_error as mse_sk,
|
||||
mean_squared_log_error as msle_sk
|
||||
)
|
||||
|
||||
from pytorch_lightning.metrics.functional import (
|
||||
mae,
|
||||
mse,
|
||||
psnr,
|
||||
rmse,
|
||||
rmsle,
|
||||
ssim
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
|
||||
pytest.param(mae_sk, mae, id='mean_absolute_error'),
|
||||
pytest.param(mse_sk, mse, id='mean_squared_error'),
|
||||
pytest.param(partial(mse_sk, squared=False), rmse, id='root_mean_squared_error'),
|
||||
pytest.param(lambda x, y: sqrt(msle_sk(x, y)), rmsle, id='root_mean_squared_log_error')
|
||||
])
|
||||
def test_against_sklearn(sklearn_metric, torch_metric):
|
||||
"""Compare PL metrics to sklearn version."""
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
# iterate over different label counts in predictions and target
|
||||
pred = torch.rand(300, device=device)
|
||||
target = torch.rand(300, device=device)
|
||||
|
||||
sk_score = sklearn_metric(target.cpu().detach().numpy(),
|
||||
pred.cpu().detach().numpy())
|
||||
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
|
||||
pl_score = torch_metric(pred, target)
|
||||
assert torch.allclose(sk_score, pl_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
|
||||
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0),
|
||||
])
|
||||
def test_mse(pred, target, expected):
|
||||
score = mse(torch.tensor(pred), torch.tensor(target))
|
||||
assert score.item() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
|
||||
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.5),
|
||||
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.7321),
|
||||
])
|
||||
def test_rmse(pred, target, expected):
|
||||
score = rmse(torch.tensor(pred), torch.tensor(target))
|
||||
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
|
||||
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
|
||||
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 1.5),
|
||||
])
|
||||
def test_mae(pred, target, expected):
|
||||
score = mae(torch.tensor(pred), torch.tensor(target))
|
||||
assert score.item() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'expected'], [
|
||||
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
|
||||
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.1438),
|
||||
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.5330),
|
||||
])
|
||||
def test_rmsle(pred, target, expected):
|
||||
score = rmsle(torch.tensor(pred), torch.tensor(target))
|
||||
assert torch.allclose(score, torch.tensor(expected), atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target'], [
|
||||
pytest.param([0., 1., 2., 3.], [0., 1., 2., 3.]),
|
||||
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
|
||||
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
|
||||
])
|
||||
def test_psnr_with_skimage(pred, target):
|
||||
score = psnr(pred=torch.tensor(pred),
|
||||
target=torch.tensor(target), data_range=3)
|
||||
sk_score = ski_psnr(np.array(pred), np.array(target), data_range=3)
|
||||
assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float), atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target'], [
|
||||
pytest.param([0., 1., 2., 3.], [0., 1., 2., 2.]),
|
||||
pytest.param([4., 3., 2., 1.], [1., 4., 3., 2.]),
|
||||
])
|
||||
def test_psnr_base_e_wider_range(pred, target):
|
||||
score = psnr(pred=torch.tensor(pred),
|
||||
target=torch.tensor(target),
|
||||
data_range=4,
|
||||
base=2.718281828459045)
|
||||
sk_score = ski_psnr(np.array(pred), np.array(target), data_range=4) * np.log(10)
|
||||
assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float32), atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
|
||||
pytest.param(ski_psnr, psnr, id='peak_signal_noise_ratio')
|
||||
])
|
||||
def test_psnr_against_sklearn(sklearn_metric, torch_metric):
|
||||
"""Compare PL metrics to sklearn version."""
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
|
||||
pred = torch.randint(n_cls_pred, (500,), device=device, dtype=torch.float)
|
||||
target = torch.randint(n_cls_target, (500,), device=device, dtype=torch.float)
|
||||
|
||||
sk_score = sklearn_metric(target.cpu().detach().numpy(),
|
||||
pred.cpu().detach().numpy(),
|
||||
data_range=n_cls_target)
|
||||
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
|
||||
pl_score = torch_metric(pred, target, data_range=n_cls_target)
|
||||
assert torch.allclose(sk_score, pl_score)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['size', 'channel', 'coef', 'multichannel'], [
|
||||
pytest.param(16, 1, 0.9, False),
|
||||
pytest.param(32, 3, 0.8, True),
|
||||
pytest.param(48, 4, 0.7, True),
|
||||
pytest.param(64, 5, 0.6, True)
|
||||
])
|
||||
def test_ssim(size, channel, coef, multichannel):
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
pred = torch.rand(size, channel, size, size, device=device)
|
||||
target = pred * coef
|
||||
ssim_idx = ssim(pred, target, data_range=1.0)
|
||||
np_pred = pred.permute(0, 2, 3, 1).cpu().numpy()
|
||||
if multichannel is False:
|
||||
np_pred = np_pred[:, :, :, 0]
|
||||
np_target = np.multiply(np_pred, coef)
|
||||
sk_ssim_idx = ski_ssim(
|
||||
np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True, data_range=1.0
|
||||
)
|
||||
assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-4)
|
||||
|
||||
ssim_idx = ssim(pred, pred)
|
||||
assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [
|
||||
pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # shape
|
||||
pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input
|
||||
pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input
|
||||
])
|
||||
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
|
||||
pred_t = torch.rand(pred)
|
||||
target_t = torch.rand(target, dtype=torch.float64)
|
||||
with pytest.raises(TypeError):
|
||||
ssim(pred_t, target_t)
|
||||
|
||||
pred = torch.rand(pred)
|
||||
target = torch.rand(target)
|
||||
with pytest.raises(ValueError):
|
||||
ssim(pred, target, kernel, sigma)
|
|
@ -1,35 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
from sklearn.metrics import pairwise
|
||||
|
||||
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity
|
||||
|
||||
|
||||
@pytest.mark.parametrize('similarity', ['cosine', 'dot'])
|
||||
@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum'])
|
||||
def test_against_sklearn(similarity, reduction):
|
||||
"""Compare PL metrics to sklearn version."""
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
|
||||
batch = torch.randn(5, 10, device=device) # 100 samples in 10 dimensions
|
||||
|
||||
pl_dist = embedding_similarity(batch, similarity=similarity,
|
||||
reduction=reduction, zero_diagonal=False)
|
||||
|
||||
def sklearn_embedding_distance(batch, similarity, reduction):
|
||||
|
||||
metric_func = {'cosine': pairwise.cosine_similarity,
|
||||
'dot': pairwise.linear_kernel}[similarity]
|
||||
|
||||
dist = metric_func(batch, batch)
|
||||
if reduction == 'mean':
|
||||
return dist.mean(axis=-1)
|
||||
if reduction == 'sum':
|
||||
return dist.sum(axis=-1)
|
||||
return dist
|
||||
|
||||
sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(),
|
||||
similarity=similarity, reduction=reduction)
|
||||
sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device)
|
||||
|
||||
assert torch.allclose(sk_dist, pl_dist)
|
|
@ -0,0 +1,57 @@
|
|||
import torch
|
||||
import pytest
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_squared_log_error
|
||||
|
||||
from tests.metrics.utils import compute_batch, setup_ddp
|
||||
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
|
||||
|
||||
num_targets = 5
|
||||
|
||||
Input = namedtuple('Input', ["preds", "target"])
|
||||
|
||||
_single_target_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
|
||||
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
|
||||
)
|
||||
|
||||
_multi_target_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
|
||||
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
|
||||
)
|
||||
|
||||
|
||||
def _single_target_sk_metric(preds, target, sk_fn=mean_squared_error):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
return sk_fn(sk_preds, sk_target)
|
||||
|
||||
|
||||
def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):
|
||||
sk_preds = preds.view(-1, num_targets).numpy()
|
||||
sk_target = target.view(-1, num_targets).numpy()
|
||||
return sk_fn(sk_preds, sk_target)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ddp", [True, False])
|
||||
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"preds, target, sk_metric",
|
||||
[
|
||||
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
|
||||
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"metric_class, sk_fn",
|
||||
[
|
||||
(MeanSquaredError, mean_squared_error),
|
||||
(MeanAbsoluteError, mean_absolute_error),
|
||||
(MeanSquaredLogError, mean_squared_log_error),
|
||||
],
|
||||
)
|
||||
def test_mean_error(ddp, ddp_sync_on_step, preds, target, sk_metric, metric_class, sk_fn):
|
||||
compute_batch(preds, target, metric_class, partial(sk_metric, sk_fn=sk_fn), ddp_sync_on_step, ddp)
|
|
@ -1,297 +0,0 @@
|
|||
import pytest
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import numpy as np
|
||||
|
||||
from tests.base import EvalModelTemplate
|
||||
from pytorch_lightning import Trainer
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning.metrics import (
|
||||
Accuracy,
|
||||
ConfusionMatrix,
|
||||
PrecisionRecallCurve,
|
||||
Precision,
|
||||
Recall,
|
||||
AveragePrecision,
|
||||
AUROC,
|
||||
FBeta,
|
||||
F1,
|
||||
ROC,
|
||||
MulticlassROC,
|
||||
MulticlassPrecisionRecallCurve,
|
||||
DiceCoefficient,
|
||||
IoU,
|
||||
MAE,
|
||||
MSE,
|
||||
RMSE,
|
||||
RMSLE,
|
||||
PSNR,
|
||||
SSIM,
|
||||
)
|
||||
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
confusion_matrix,
|
||||
precision_recall_curve,
|
||||
precision_score,
|
||||
recall_score,
|
||||
average_precision_score,
|
||||
roc_auc_score,
|
||||
fbeta_score,
|
||||
f1_score,
|
||||
roc_curve,
|
||||
jaccard_score,
|
||||
mean_squared_error,
|
||||
mean_absolute_error,
|
||||
mean_squared_log_error
|
||||
)
|
||||
|
||||
from skimage.metrics import (
|
||||
peak_signal_noise_ratio,
|
||||
structural_similarity
|
||||
)
|
||||
|
||||
# example structure
|
||||
TestCase = namedtuple('example', ['name', 'lightning_metric', 'comparing_metric', 'test_input'])
|
||||
|
||||
# setup some standard testcases
|
||||
NB_SAMPLES = 200
|
||||
multiclass_example = [(torch.randint(10, (NB_SAMPLES,)), torch.randint(10, (NB_SAMPLES,)))]
|
||||
binary_example = [(torch.randint(2, (NB_SAMPLES,)), torch.randint(2, (NB_SAMPLES,)))]
|
||||
multiclass_and_binary_example = [*multiclass_example, *binary_example]
|
||||
binary_example_logits = (torch.randint(2, (NB_SAMPLES,)), torch.randint(5, (NB_SAMPLES,)))
|
||||
multiclass_example_probs = (torch.randint(10, (NB_SAMPLES,)), torch.randn((NB_SAMPLES, 10)).softmax(-1))
|
||||
regression_example = [(torch.rand((NB_SAMPLES,)), torch.rand((NB_SAMPLES,)))]
|
||||
|
||||
|
||||
# construct additional test functions
|
||||
def root_mean_squared_error(x, y):
|
||||
return math.sqrt(mean_squared_error(x, y))
|
||||
|
||||
|
||||
def root_mean_squared_log_error(x, y):
|
||||
return math.sqrt(mean_squared_log_error(x, y))
|
||||
|
||||
|
||||
# Define testcases
|
||||
# TODO: update remaining metrics and uncomment the corresponding test cases
|
||||
TESTS = [
|
||||
TestCase('accuracy',
|
||||
Accuracy,
|
||||
accuracy_score,
|
||||
multiclass_and_binary_example),
|
||||
TestCase('confusion matrix without normalize',
|
||||
ConfusionMatrix,
|
||||
confusion_matrix,
|
||||
multiclass_and_binary_example),
|
||||
TestCase('confusion matrix with normalize',
|
||||
partial(ConfusionMatrix, normalize=True),
|
||||
partial(confusion_matrix, normalize='true'),
|
||||
multiclass_and_binary_example),
|
||||
# TestCase('precision recall curve',
|
||||
# PrecisionRecallCurve,
|
||||
# precision_recall_curve,
|
||||
# binary_example),
|
||||
TestCase('precision',
|
||||
Precision,
|
||||
partial(precision_score, average='micro'),
|
||||
multiclass_and_binary_example),
|
||||
TestCase('recall',
|
||||
Recall,
|
||||
partial(recall_score, average='micro'),
|
||||
multiclass_and_binary_example),
|
||||
# TestCase('average_precision',
|
||||
# AveragePrecision,
|
||||
# average_precision_score,
|
||||
# binary_example),
|
||||
# TestCase('auroc',
|
||||
# AUROC,
|
||||
# roc_auc_score,
|
||||
# binary_example),
|
||||
TestCase('f beta',
|
||||
partial(FBeta, beta=2),
|
||||
partial(fbeta_score, average='micro', beta=2),
|
||||
multiclass_and_binary_example),
|
||||
TestCase('f1',
|
||||
F1,
|
||||
partial(f1_score, average='micro'),
|
||||
multiclass_and_binary_example),
|
||||
# TestCase('roc',
|
||||
# ROC,
|
||||
# roc_curve,
|
||||
# binary_example),
|
||||
# TestCase('multiclass roc',
|
||||
# MulticlassROC,
|
||||
# multiclass_roc,
|
||||
# binary_example),
|
||||
# TestCase('multiclass precision recall curve',
|
||||
# MulticlassPrecisionRecallCurve,
|
||||
# multiclass_precision_recall_curve,
|
||||
# binary_example),
|
||||
# TestCase('dice coefficient',
|
||||
# DiceCoefficient,
|
||||
# partial(f1_score, average='micro'),
|
||||
# multiclass_and_binary_example),
|
||||
# TestCase('intersection over union',
|
||||
# IoU,
|
||||
# partial(jaccard_score, average='macro'),
|
||||
# binary_example),
|
||||
TestCase('mean squared error',
|
||||
MSE,
|
||||
mean_squared_error,
|
||||
regression_example),
|
||||
TestCase('root mean squared error',
|
||||
RMSE,
|
||||
root_mean_squared_error,
|
||||
regression_example),
|
||||
TestCase('mean absolute error',
|
||||
MAE,
|
||||
mean_absolute_error,
|
||||
regression_example),
|
||||
TestCase('root mean squared log error',
|
||||
RMSLE,
|
||||
root_mean_squared_log_error,
|
||||
regression_example),
|
||||
TestCase('peak signal-to-noise ratio',
|
||||
partial(PSNR, data_range=10),
|
||||
partial(peak_signal_noise_ratio, data_range=10),
|
||||
regression_example),
|
||||
# TestCase('structual similarity index measure',
|
||||
# SSIM,
|
||||
# structural_similarity,
|
||||
# regression_example)
|
||||
]
|
||||
|
||||
|
||||
# Utility test functions
|
||||
def _idsfn(test):
|
||||
""" Return id for current example being tested """
|
||||
return test.name
|
||||
|
||||
|
||||
def _setup_ddp(rank, worldsize):
|
||||
""" setup ddp enviroment for testing """
|
||||
import os
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
# initialize the process group
|
||||
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
|
||||
|
||||
|
||||
def comparing_fn(lightning_val, comparing_val, rtol=1e-03, atol=1e-08):
|
||||
""" function for comparing output, both multi and single output"""
|
||||
# multi output
|
||||
if isinstance(comparing_val, tuple):
|
||||
for l_score, c_score in zip(lightning_val, comparing_val):
|
||||
assert np.allclose(l_score.numpy(), c_score, rtol, atol)
|
||||
else: # single output
|
||||
assert np.allclose(lightning_val.numpy(), comparing_val, rtol, atol)
|
||||
|
||||
|
||||
# ===== Tests start here =====
|
||||
def _test_ddp_single_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs):
|
||||
""" ddp testing function, divide test_inputs equally between all processes """
|
||||
_setup_ddp(rank, worldsize)
|
||||
|
||||
# Setup metric for ddp
|
||||
lightning_metric = lightning_metric()
|
||||
for test_input in test_inputs:
|
||||
# rank 0 receives sample 0,2,4,...
|
||||
# rank 1 receives sample 1,3,5,...
|
||||
lightning_val = lightning_metric(*[ti[rank::2] for ti in test_input])
|
||||
|
||||
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
|
||||
|
||||
comparing_fn(lightning_val, comparing_val)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
|
||||
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
|
||||
def test_ddp(test):
|
||||
"""Make sure that metrics are correctly sync and reduced in DDP mode"""
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
worldsize = 2
|
||||
mp.spawn(_test_ddp_single_batch,
|
||||
args=(worldsize,
|
||||
test.lightning_metric,
|
||||
test.comparing_metric,
|
||||
test.test_input),
|
||||
nprocs=worldsize)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
|
||||
def test_multi_batch(test):
|
||||
""" test that aggregation works for multiple batches """
|
||||
lightning_metric = test.lightning_metric()
|
||||
comparing_metric = test.comparing_metric
|
||||
|
||||
for test_input in test.test_input:
|
||||
for i in range(2): # for lightning device in 2 artificially batches
|
||||
# first batch consist of samples 0,2,4,...
|
||||
# second batch consist of samples 1,3,5,...
|
||||
_ = lightning_metric(*[ti[i::2] for ti in test_input])
|
||||
lightning_val = lightning_metric.aggregated
|
||||
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
|
||||
|
||||
comparing_fn(lightning_val, comparing_val)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
|
||||
def test_multi_batch_unequal_sizes(test):
|
||||
""" test that aggregation works for multiple batches with uneven sizes """
|
||||
lightning_metric = test.lightning_metric()
|
||||
comparing_metric = test.comparing_metric
|
||||
|
||||
for test_input in test.test_input:
|
||||
for i in range(2): # for lightning device in 2 artificially batches
|
||||
if i == 0: # allocate 3/4 of data to the first batch
|
||||
_ = lightning_metric(*[ti[:int(3 / 4 * len(ti))] for ti in test_input])
|
||||
else:
|
||||
_ = lightning_metric(*[ti[int(3 / 4 * len(ti)):] for ti in test_input])
|
||||
lightning_val = lightning_metric.aggregated
|
||||
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
|
||||
|
||||
comparing_fn(lightning_val, comparing_val)
|
||||
|
||||
|
||||
def _test_ddp_multi_batch(rank, worldsize, lightning_metric, comparing_metric, test_inputs):
|
||||
""" ddp testing function, test that metric works with aggregation over multiple
|
||||
devices and multiple batches """
|
||||
_setup_ddp(rank, worldsize)
|
||||
|
||||
# Setup metric for ddp
|
||||
lightning_metric = lightning_metric()
|
||||
for test_input in test_inputs:
|
||||
for i in range(2): # artificially divide samples between batches and processes
|
||||
# rank 0, batch 0 consist of samples 0,4,8,...
|
||||
# rank 0, batch 1 consist of samples 1,5,9,...
|
||||
# rank 1, batch 0 consist of samples 2,6,10,...
|
||||
# rank 1, batch 1 consist of samples 3,7,11,...
|
||||
_ = lightning_metric(*[ti[i + worldsize * rank::4] for ti in test_input])
|
||||
lightning_val = lightning_metric.aggregated
|
||||
comparing_val = comparing_metric(*[ti.numpy() for ti in reversed(test_input)])
|
||||
|
||||
comparing_fn(lightning_val, comparing_val)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
|
||||
@pytest.mark.parametrize("test", TESTS, ids=_idsfn)
|
||||
def test_ddp_multi_batch(test):
|
||||
""" test that aggregation works fine with in DDP mode and multiple batches """
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
worldsize = 2
|
||||
mp.spawn(_test_ddp_multi_batch,
|
||||
args=(worldsize,
|
||||
test.lightning_metric,
|
||||
test.comparing_metric,
|
||||
test.test_input),
|
||||
nprocs=worldsize)
|
|
@ -1,237 +0,0 @@
|
|||
# NOTE: This file only tests if modules with arguments are running fine.
|
||||
# The actual metric implementation is tested in functional/test_classification.py
|
||||
# Especially reduction and reducing across processes won't be tested here!
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.classification import (
|
||||
Accuracy,
|
||||
ConfusionMatrix,
|
||||
PrecisionRecallCurve,
|
||||
Precision,
|
||||
Recall,
|
||||
AveragePrecision,
|
||||
AUROC,
|
||||
FBeta,
|
||||
F1,
|
||||
ROC,
|
||||
MulticlassROC,
|
||||
MulticlassPrecisionRecallCurve,
|
||||
DiceCoefficient,
|
||||
IoU,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random():
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_accuracy(num_classes):
|
||||
acc = Accuracy(num_classes=num_classes)
|
||||
assert acc.name == 'accuracy'
|
||||
|
||||
result = acc(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
|
||||
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
|
||||
assert isinstance(result, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['normalize', 'num_classes'], [
|
||||
pytest.param(False, None),
|
||||
pytest.param(True, None),
|
||||
pytest.param(False, 3)
|
||||
])
|
||||
def test_confusion_matrix(normalize, num_classes):
|
||||
conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
|
||||
assert conf_matrix.name == 'confusion_matrix'
|
||||
|
||||
target = (torch.arange(120) % 3).view(-1, 1)
|
||||
pred = target.clone()
|
||||
cm = conf_matrix(pred, target)
|
||||
assert isinstance(cm, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['normalize', 'num_classes'], [
|
||||
pytest.param(True, 3)
|
||||
])
|
||||
def test_confusion_matrix_norm(normalize, num_classes):
|
||||
""" test that user is warned if confusion matrix contains nans that are changed to zeros"""
|
||||
conf_matrix = ConfusionMatrix(normalize=normalize, num_classes=num_classes)
|
||||
assert conf_matrix.name == 'confusion_matrix'
|
||||
|
||||
with pytest.warns(UserWarning, match='6 nan values found in confusion matrix have been replaced with zeros.'):
|
||||
target = torch.LongTensor([0] * 5)
|
||||
pred = target.clone()
|
||||
cm = conf_matrix(pred, target)
|
||||
assert isinstance(cm, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2.])
|
||||
def test_precision_recall(pos_label):
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
|
||||
|
||||
pr_curve = PrecisionRecallCurve(pos_label=pos_label)
|
||||
assert pr_curve.name == 'precision_recall_curve'
|
||||
|
||||
pr = pr_curve(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
|
||||
assert isinstance(pr, tuple)
|
||||
assert len(pr) == 3
|
||||
for tmp in pr:
|
||||
assert isinstance(tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_precision(num_classes):
|
||||
precision = Precision(num_classes=num_classes)
|
||||
assert precision.name == 'precision'
|
||||
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
|
||||
prec = precision(pred=pred, target=target)
|
||||
assert isinstance(prec, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_recall(num_classes):
|
||||
recall = Recall(num_classes=num_classes)
|
||||
assert recall.name == 'recall'
|
||||
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1])
|
||||
rec = recall(pred=pred, target=target)
|
||||
assert isinstance(rec, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2])
|
||||
def test_average_precision(pos_label):
|
||||
avg_prec = AveragePrecision(pos_label=pos_label)
|
||||
assert avg_prec.name == 'AP'
|
||||
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
|
||||
ap = avg_prec(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
assert isinstance(ap, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [0, 1])
|
||||
def test_auroc(pos_label):
|
||||
auroc = AUROC(pos_label=pos_label)
|
||||
assert auroc.name == 'auroc'
|
||||
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 1, 0, 1])
|
||||
area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
assert isinstance(area, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(['beta', 'num_classes'], [
|
||||
pytest.param(0., 1),
|
||||
pytest.param(0.5, 1),
|
||||
pytest.param(1., 1),
|
||||
pytest.param(2., 1),
|
||||
pytest.param(0., None),
|
||||
pytest.param(0.5, None),
|
||||
pytest.param(1., None),
|
||||
pytest.param(2., None)
|
||||
])
|
||||
def test_fbeta(beta, num_classes):
|
||||
fbeta = FBeta(beta=beta, num_classes=num_classes)
|
||||
assert fbeta.name == 'fbeta'
|
||||
|
||||
score = fbeta(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
|
||||
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [1, None])
|
||||
def test_f1(num_classes):
|
||||
f1 = F1(num_classes=num_classes)
|
||||
assert f1.name == 'f1'
|
||||
|
||||
score = f1(pred=torch.tensor([[0, 1, 1], [1, 0, 1]]),
|
||||
target=torch.tensor([[0, 0, 1], [1, 0, 1]]))
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('pos_label', [1, 2])
|
||||
def test_roc(pos_label):
|
||||
roc = ROC(pos_label=pos_label)
|
||||
assert roc.name == 'roc'
|
||||
|
||||
pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 4, 3])
|
||||
res = roc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
|
||||
|
||||
assert isinstance(res, tuple)
|
||||
assert len(res) == 3
|
||||
for tmp in res:
|
||||
assert isinstance(tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [4, None])
|
||||
def test_multiclass_roc(num_classes):
|
||||
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
[0.05, 0.85, 0.05, 0.05],
|
||||
[0.05, 0.05, 0.85, 0.05],
|
||||
[0.05, 0.05, 0.05, 0.85]])
|
||||
target = torch.tensor([0, 1, 3, 2])
|
||||
|
||||
multi_roc = MulticlassROC(num_classes=num_classes)
|
||||
assert multi_roc.name == 'multiclass_roc'
|
||||
|
||||
res = multi_roc(pred, target)
|
||||
assert isinstance(res, tuple)
|
||||
|
||||
if num_classes is not None:
|
||||
assert len(res) == num_classes
|
||||
|
||||
for tmp in res:
|
||||
assert isinstance(tmp, tuple)
|
||||
assert len(tmp) == 3
|
||||
|
||||
for _tmp in tmp:
|
||||
assert isinstance(_tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_classes', [4, None])
|
||||
def test_multiclass_pr(num_classes):
|
||||
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
|
||||
[0.05, 0.85, 0.05, 0.05],
|
||||
[0.05, 0.05, 0.85, 0.05],
|
||||
[0.05, 0.05, 0.05, 0.85]])
|
||||
target = torch.tensor([0, 1, 3, 2])
|
||||
|
||||
multi_pr = MulticlassPrecisionRecallCurve(num_classes=num_classes)
|
||||
assert multi_pr.name == 'multiclass_precision_recall_curve'
|
||||
|
||||
pr = multi_pr(pred, target)
|
||||
assert isinstance(pr, tuple)
|
||||
|
||||
if num_classes is not None:
|
||||
assert len(pr) == num_classes
|
||||
|
||||
for tmp in pr:
|
||||
assert isinstance(tmp, tuple)
|
||||
assert len(tmp) == 3
|
||||
|
||||
for _tmp in tmp:
|
||||
assert isinstance(_tmp, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('include_background', [True, False])
|
||||
def test_dice_coefficient(include_background):
|
||||
dice_coeff = DiceCoefficient(include_background=include_background)
|
||||
assert dice_coeff.name == 'dice'
|
||||
|
||||
dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)),
|
||||
torch.randint(0, 1, (10, 25, 25)))
|
||||
assert isinstance(dice, torch.Tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('ignore_index', [0, 1, None])
|
||||
def test_iou(ignore_index):
|
||||
iou = IoU(ignore_index=ignore_index)
|
||||
assert iou.name == 'iou'
|
||||
|
||||
score = iou(torch.randint(0, 1, (10, 25, 25)),
|
||||
torch.randint(0, 1, (10, 25, 25)))
|
||||
|
||||
assert isinstance(score, torch.Tensor)
|
|
@ -1,265 +0,0 @@
|
|||
import sys
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning.metrics.converters import (
|
||||
_apply_to_inputs,
|
||||
_apply_to_outputs,
|
||||
convert_to_tensor,
|
||||
convert_to_numpy,
|
||||
_numpy_metric_conversion,
|
||||
_tensor_metric_conversion,
|
||||
sync_ddp_if_available,
|
||||
gather_all_tensors_if_available,
|
||||
tensor_metric,
|
||||
numpy_metric
|
||||
)
|
||||
|
||||
|
||||
def test_apply_to_inputs():
|
||||
def apply_fn(inputs, factor):
|
||||
if isinstance(inputs, (float, int)):
|
||||
return inputs * factor
|
||||
elif isinstance(inputs, dict):
|
||||
return {k: apply_fn(v, factor) for k, v in inputs.items()}
|
||||
elif isinstance(inputs, (tuple, list)):
|
||||
return [apply_fn(x, factor) for x in inputs]
|
||||
|
||||
@_apply_to_inputs(apply_fn, factor=2.)
|
||||
def test_fn(*args, **kwargs):
|
||||
return args, kwargs
|
||||
|
||||
for args in [[], [1., 2.]]:
|
||||
for kwargs in [{}, {'a': 1., 'b': 2.}]:
|
||||
result_args, result_kwargs = test_fn(*args, **kwargs)
|
||||
assert isinstance(result_args, (list, tuple))
|
||||
assert isinstance(result_kwargs, dict)
|
||||
assert len(result_args) == len(args)
|
||||
assert len(result_kwargs) == len(kwargs)
|
||||
assert all([k in result_kwargs for k in kwargs.keys()])
|
||||
for arg, result_arg in zip(args, result_args):
|
||||
assert arg * 2. == result_arg
|
||||
|
||||
for key in kwargs.keys():
|
||||
arg = kwargs[key]
|
||||
result_arg = result_kwargs[key]
|
||||
assert arg * 2. == result_arg
|
||||
|
||||
|
||||
def test_apply_to_outputs():
|
||||
def apply_fn(inputs, additional_str):
|
||||
return str(inputs) + additional_str
|
||||
|
||||
@_apply_to_outputs(apply_fn, additional_str='_str')
|
||||
def test_fn(*args, **kwargs):
|
||||
return 'dummy'
|
||||
|
||||
assert test_fn() == 'dummy_str'
|
||||
|
||||
|
||||
def test_convert_to_tensor():
|
||||
for test_item in [1., np.array([1.])]:
|
||||
result_tensor = convert_to_tensor(test_item)
|
||||
assert isinstance(result_tensor, torch.Tensor)
|
||||
assert result_tensor.item() == 1.
|
||||
|
||||
|
||||
def test_convert_to_numpy():
|
||||
for test_item in [1., torch.tensor([1.])]:
|
||||
result = convert_to_numpy(test_item)
|
||||
assert isinstance(result, np.ndarray)
|
||||
assert result.item() == 1.
|
||||
|
||||
|
||||
def test_numpy_metric_conversion():
|
||||
@_numpy_metric_conversion
|
||||
def numpy_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, np.ndarray)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, np.ndarray)
|
||||
|
||||
return 5.
|
||||
|
||||
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5.
|
||||
|
||||
|
||||
def test_tensor_metric_conversion():
|
||||
@_tensor_metric_conversion
|
||||
def tensor_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, torch.Tensor)
|
||||
|
||||
return 5.
|
||||
|
||||
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5.
|
||||
|
||||
|
||||
def _setup_ddp(rank, worldsize):
|
||||
import os
|
||||
|
||||
os.environ['MASTER_ADDR'] = 'localhost'
|
||||
|
||||
# initialize the process group
|
||||
dist.init_process_group("gloo", rank=rank, world_size=worldsize)
|
||||
|
||||
|
||||
def _ddp_test_fn(rank, worldsize, add_offset: bool, reduction_mean=False):
|
||||
_setup_ddp(rank, worldsize)
|
||||
if add_offset:
|
||||
tensor = torch.tensor([float(rank)])
|
||||
else:
|
||||
tensor = torch.tensor([1.], )
|
||||
if reduction_mean:
|
||||
reduced_tensor = sync_ddp_if_available(tensor, reduce_op='avg')
|
||||
|
||||
manual_reduction = sum([i for i in range(dist.get_world_size())]) / dist.get_world_size()
|
||||
assert reduced_tensor.item() == manual_reduction
|
||||
else:
|
||||
reduced_tensor = sync_ddp_if_available(tensor)
|
||||
|
||||
assert reduced_tensor.item() == dist.get_world_size(), \
|
||||
'Sync-Reduce does not work properly with DDP and Tensors'
|
||||
|
||||
|
||||
def _ddp_test_gather_all_tensors(rank, worldsize):
|
||||
_setup_ddp(rank, worldsize)
|
||||
|
||||
tensor = torch.tensor([rank])
|
||||
gather_tensors = gather_all_tensors_if_available(tensor)
|
||||
mannual_tensors = [torch.tensor([i]) for i in range(worldsize)]
|
||||
|
||||
for t1, t2 in zip(gather_tensors, mannual_tensors):
|
||||
assert(t1.equal(t2))
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
|
||||
def test_sync_reduce_ddp():
|
||||
"""Make sure sync-reduce works with DDP"""
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
worldsize = 2
|
||||
mp.spawn(_ddp_test_fn, args=(worldsize, False), nprocs=worldsize)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
|
||||
def test_sync_reduce_ddp_mean():
|
||||
"""Make sure sync-reduce works with DDP"""
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
worldsize = 2
|
||||
mp.spawn(_ddp_test_fn, args=(worldsize, True, True), nprocs=worldsize)
|
||||
|
||||
|
||||
def test_sync_reduce_simple():
|
||||
"""Make sure sync-reduce works without DDP"""
|
||||
tensor = torch.tensor([1.], device='cpu')
|
||||
|
||||
reduced_tensor = sync_ddp_if_available(tensor)
|
||||
|
||||
assert torch.allclose(tensor, reduced_tensor), \
|
||||
'Sync-Reduce does not work properly without DDP and Tensors'
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
|
||||
def test_gather_all_tensors_ddp():
|
||||
"""Make sure gather_all_tensors works with DDP"""
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
worldsize = 2
|
||||
mp.spawn(_ddp_test_gather_all_tensors, args=(worldsize, ), nprocs=worldsize)
|
||||
|
||||
|
||||
def _test_tensor_metric(is_ddp: bool):
|
||||
@tensor_metric()
|
||||
def tensor_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, torch.Tensor)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, torch.Tensor)
|
||||
|
||||
return 5.
|
||||
|
||||
if is_ddp:
|
||||
factor = dist.get_world_size()
|
||||
else:
|
||||
factor = 1.
|
||||
|
||||
result = tensor_test_metric(np.array([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5. * factor
|
||||
|
||||
|
||||
def _ddp_test_tensor_metric(rank, worldsize):
|
||||
_setup_ddp(rank, worldsize)
|
||||
_test_tensor_metric(True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
|
||||
def test_tensor_metric_ddp():
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
|
||||
world_size = 2
|
||||
mp.spawn(_ddp_test_tensor_metric, args=(world_size,), nprocs=world_size)
|
||||
# dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_tensor_metric_simple():
|
||||
_test_tensor_metric(False)
|
||||
|
||||
|
||||
def _test_numpy_metric(is_ddp: bool):
|
||||
@numpy_metric()
|
||||
def numpy_test_metric(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert isinstance(arg, np.ndarray)
|
||||
|
||||
for v in kwargs.values():
|
||||
assert isinstance(v, np.ndarray)
|
||||
|
||||
return 5.
|
||||
|
||||
if is_ddp:
|
||||
factor = dist.get_world_size()
|
||||
else:
|
||||
factor = 1.
|
||||
|
||||
result = numpy_test_metric(torch.tensor([1.]), dummy_kwarg=2.)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert result.item() == 5. * factor
|
||||
|
||||
|
||||
def _ddp_test_numpy_metric(rank, worldsize):
|
||||
_setup_ddp(rank, worldsize)
|
||||
_test_numpy_metric(True)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32" , reason="DDP not available on windows")
|
||||
def test_numpy_metric_ddp():
|
||||
tutils.reset_seed()
|
||||
tutils.set_random_master_port()
|
||||
world_size = 2
|
||||
mp.spawn(_ddp_test_numpy_metric, args=(world_size,), nprocs=world_size)
|
||||
# dist.destroy_process_group()
|
||||
|
||||
|
||||
def test_numpy_metric_simple():
|
||||
_test_numpy_metric(False)
|
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
|
||||
from tests.metrics.test_metric import Dummy
|
||||
from tests.metrics.utils import setup_ddp
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
def _test_ddp_sum(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
dummy = Dummy()
|
||||
dummy._reductions = {"foo": torch.sum}
|
||||
dummy.foo = torch.tensor(1)
|
||||
|
||||
dummy._sync_dist()
|
||||
assert dummy.foo == worldsize
|
||||
|
||||
|
||||
def _test_ddp_cat(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
dummy = Dummy()
|
||||
dummy._reductions = {"foo": torch.cat}
|
||||
dummy.foo = [torch.tensor([1])]
|
||||
dummy._sync_dist()
|
||||
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
|
||||
|
||||
|
||||
def _test_ddp_sum_cat(rank, worldsize):
|
||||
setup_ddp(rank, worldsize)
|
||||
dummy = Dummy()
|
||||
dummy._reductions = {"foo": torch.cat, "bar": torch.sum}
|
||||
dummy.foo = [torch.tensor([1])]
|
||||
dummy.bar = torch.tensor(1)
|
||||
dummy._sync_dist()
|
||||
assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1])))
|
||||
assert dummy.bar == worldsize
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
|
||||
@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat])
|
||||
def test_ddp(process):
|
||||
torch.multiprocessing.spawn(process, args=(2,), nprocs=2)
|
|
@ -0,0 +1,108 @@
|
|||
import pytest
|
||||
import torch
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
|
||||
class Dummy(Metric):
|
||||
name = "Dummy"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.add_state("x", torch.tensor(0), dist_reduce_fx=None)
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def compute(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_inherit():
|
||||
a = Dummy()
|
||||
|
||||
|
||||
def test_add_state():
|
||||
a = Dummy()
|
||||
|
||||
a.add_state("a", torch.tensor(0), "sum")
|
||||
assert a._reductions["a"](torch.tensor([1, 1])) == 2
|
||||
|
||||
a.add_state("b", torch.tensor(0), "mean")
|
||||
assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5)
|
||||
|
||||
a.add_state("c", torch.tensor(0), "cat")
|
||||
assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2,)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d1", torch.tensor(0), 'xyz')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d2", torch.tensor(0), 42)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d3", [torch.tensor(0)], 'sum')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
a.add_state("d4", 42, 'sum')
|
||||
|
||||
def custom_fx(x):
|
||||
return -1
|
||||
|
||||
a.add_state("e", torch.tensor(0), custom_fx)
|
||||
assert a._reductions["e"](torch.tensor([1, 1])) == -1
|
||||
|
||||
|
||||
def test_reset():
|
||||
class A(Dummy):
|
||||
pass
|
||||
|
||||
a = A()
|
||||
assert a.x == 0
|
||||
a.x = torch.tensor(5)
|
||||
a.reset()
|
||||
assert a.x == 0
|
||||
|
||||
|
||||
def test_update():
|
||||
class A(Dummy):
|
||||
def update(self, x):
|
||||
self.x += x
|
||||
|
||||
a = A()
|
||||
assert a.x == 0
|
||||
assert a._computed is None
|
||||
a.update(1)
|
||||
assert a._computed is None
|
||||
assert a.x == 1
|
||||
a.update(2)
|
||||
assert a.x == 3
|
||||
assert a._computed is None
|
||||
|
||||
|
||||
def test_compute():
|
||||
class A(Dummy):
|
||||
def update(self, x):
|
||||
self.x += x
|
||||
|
||||
def compute(self):
|
||||
return self.x
|
||||
|
||||
a = A()
|
||||
assert 0 == a.compute()
|
||||
assert 0 == a.x
|
||||
a.update(1)
|
||||
assert a._computed is None
|
||||
assert a.compute() == 1
|
||||
assert a._computed == 1
|
||||
a.update(2)
|
||||
assert a._computed is None
|
||||
assert a.compute() == 2
|
||||
assert a._computed == 2
|
||||
|
||||
# called without update, should return cached value
|
||||
a._computed = 5
|
||||
assert a.compute() == 5
|
|
@ -1,323 +0,0 @@
|
|||
import os
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import tests.base.develop_utils as tutils
|
||||
from tests.base import EvalModelTemplate
|
||||
from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric
|
||||
from pytorch_lightning import Trainer
|
||||
|
||||
|
||||
class DummyTensorMetric(TensorMetric):
|
||||
def __init__(self):
|
||||
super().__init__("dummy")
|
||||
|
||||
def forward(self, input1, input2):
|
||||
assert isinstance(input1, torch.Tensor)
|
||||
assert isinstance(input2, torch.Tensor)
|
||||
return torch.tensor([1.0])
|
||||
|
||||
|
||||
class DummyNumpyMetric(NumpyMetric):
|
||||
def __init__(self):
|
||||
super().__init__("dummy")
|
||||
|
||||
def forward(self, input1, input2):
|
||||
assert isinstance(input1, np.ndarray)
|
||||
assert isinstance(input2, np.ndarray)
|
||||
return 1.0
|
||||
|
||||
|
||||
class DummyTensorCollectionMetric(TensorMetric):
|
||||
def __init__(self):
|
||||
super().__init__("dummy")
|
||||
|
||||
def forward(self, input1, input2):
|
||||
assert isinstance(input1, torch.Tensor)
|
||||
assert isinstance(input2, torch.Tensor)
|
||||
return 1.0, 2.0, 3.0, 4.0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric", [DummyTensorCollectionMetric()])
|
||||
def test_collection_metric(metric: Metric):
|
||||
""" Test that metric.device, metric.dtype works for metric collection """
|
||||
input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])
|
||||
|
||||
def change_and_check_device_dtype(device, dtype):
|
||||
metric.to(device=device, dtype=dtype)
|
||||
|
||||
metric_val = metric(input1, input2)
|
||||
assert not isinstance(metric_val, torch.Tensor)
|
||||
|
||||
if device is not None:
|
||||
assert metric.device in [device, torch.device(device)]
|
||||
|
||||
if dtype is not None:
|
||||
assert metric.dtype == dtype
|
||||
|
||||
devices = [None, "cpu"]
|
||||
if torch.cuda.is_available():
|
||||
devices += ["cuda:0"]
|
||||
|
||||
for device in devices:
|
||||
for dtype in [None, torch.float32, torch.float64]:
|
||||
change_and_check_device_dtype(device=device, dtype=dtype)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda(0)
|
||||
assert metric.device == torch.device("cuda", index=0)
|
||||
|
||||
metric.cpu()
|
||||
assert metric.device == torch.device("cpu")
|
||||
|
||||
metric.type(torch.int8)
|
||||
assert metric.dtype == torch.int8
|
||||
|
||||
metric.float()
|
||||
assert metric.dtype == torch.float32
|
||||
|
||||
metric.double()
|
||||
assert metric.dtype == torch.float64
|
||||
assert all(out.dtype == torch.float64 for out in metric(input1, input2))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda()
|
||||
metric.half()
|
||||
assert metric.dtype == torch.float16
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"metric",
|
||||
[
|
||||
DummyTensorMetric(),
|
||||
DummyNumpyMetric(),
|
||||
],
|
||||
)
|
||||
def test_metric(metric: Metric):
|
||||
""" Test that metric.device, metric.dtype works for single metric"""
|
||||
input1, input2 = torch.tensor([1.0]), torch.tensor([2.0])
|
||||
|
||||
def change_and_check_device_dtype(device, dtype):
|
||||
metric.to(device=device, dtype=dtype)
|
||||
|
||||
metric_val = metric(input1, input2)
|
||||
assert isinstance(metric_val, torch.Tensor)
|
||||
|
||||
if device is not None:
|
||||
assert metric.device in [device, torch.device(device)]
|
||||
assert metric_val.device in [device, torch.device(device)]
|
||||
|
||||
if dtype is not None:
|
||||
assert metric.dtype == dtype
|
||||
assert metric_val.dtype == dtype
|
||||
|
||||
devices = [None, "cpu"]
|
||||
if torch.cuda.is_available():
|
||||
devices += ["cuda:0"]
|
||||
|
||||
for device in devices:
|
||||
for dtype in [None, torch.float32, torch.float64]:
|
||||
change_and_check_device_dtype(device=device, dtype=dtype)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda(0)
|
||||
assert metric.device == torch.device("cuda", index=0)
|
||||
assert metric(input1, input2).device == torch.device("cuda", index=0)
|
||||
|
||||
metric.cpu()
|
||||
assert metric.device == torch.device("cpu")
|
||||
assert metric(input1, input2).device == torch.device("cpu")
|
||||
|
||||
metric.float()
|
||||
assert metric.dtype == torch.float32
|
||||
assert metric(input1, input2).dtype == torch.float32
|
||||
|
||||
metric.double()
|
||||
assert metric.dtype == torch.float64
|
||||
assert metric(input1, input2).dtype == torch.float64
|
||||
|
||||
if torch.cuda.is_available():
|
||||
metric.cuda()
|
||||
metric.half()
|
||||
assert metric.dtype == torch.float16
|
||||
assert metric(input1, input2).dtype == torch.float16
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.parametrize("metric", [DummyTensorMetric, DummyNumpyMetric])
|
||||
def test_model_pickable(tmpdir, metric: Metric):
|
||||
"""Make sure that metrics are pickable by including into a model and running in multi-gpu mode"""
|
||||
tutils.set_random_master_port()
|
||||
|
||||
trainer_options = dict(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_train_batches=10,
|
||||
gpus=[0, 1],
|
||||
distributed_backend="ddp_spawn",
|
||||
)
|
||||
|
||||
model = EvalModelTemplate()
|
||||
model.metric = metric()
|
||||
model.training_step = model.training_step__using_metrics
|
||||
|
||||
trainer = Trainer(**trainer_options)
|
||||
result = trainer.fit(model)
|
||||
|
||||
# correct result and ok accuracy
|
||||
assert result == 1, "ddp model failed to complete"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("metric", [DummyTensorMetric(), DummyNumpyMetric()])
|
||||
def test_saving_pickable(tmpdir, metric: Metric):
|
||||
""" Make sure that metrics are pickable by saving and loading them using torch """
|
||||
x, y = torch.randn(10,), torch.randn(
|
||||
10,
|
||||
)
|
||||
results_before_save = metric(x, y)
|
||||
|
||||
# save metric
|
||||
save_path = os.path.join(tmpdir, "save_test.ckpt")
|
||||
torch.save(metric, save_path)
|
||||
|
||||
# load metric
|
||||
new_metric = torch.load(save_path)
|
||||
results_after_load = new_metric(x, y)
|
||||
|
||||
# Check metric value is the same
|
||||
assert results_before_save == results_after_load
|
||||
|
||||
|
||||
def test_correct_call_order():
|
||||
""" Check that hooks are called in the expected order """
|
||||
|
||||
class DummyMetric(Metric):
|
||||
def __init__(self):
|
||||
super().__init__("dummy")
|
||||
self.call_history = ["init"]
|
||||
|
||||
@staticmethod
|
||||
def input_convert(self, data: Any):
|
||||
self.call_history.append("input_convert")
|
||||
return super(DummyMetric, self).input_convert(self, data)
|
||||
|
||||
def forward(self, tensor1, tensor2):
|
||||
self.call_history.append("forward")
|
||||
return tensor1 - tensor2
|
||||
|
||||
@staticmethod
|
||||
def output_convert(self, data: Any, output: Any):
|
||||
self.call_history.append("output_convert")
|
||||
return super(DummyMetric, self).output_convert(self, data, output)
|
||||
|
||||
def ddp_sync(self, tensor: Any):
|
||||
self.call_history.append("ddp_sync")
|
||||
return super().ddp_sync(tensor)
|
||||
|
||||
@staticmethod
|
||||
def ddp_reduce(self, data: Any, output: Any):
|
||||
self.call_history.append("ddp_reduce")
|
||||
return super(DummyMetric, self).ddp_reduce(self, data, output)
|
||||
|
||||
def aggregate(self, *tensors: torch.Tensor) -> torch.Tensor:
|
||||
self.call_history.append("aggregate")
|
||||
return super().aggregate(*tensors)
|
||||
|
||||
def reset(self):
|
||||
self.call_history.append("reset")
|
||||
return super().reset()
|
||||
|
||||
@property
|
||||
def aggregated(self) -> torch.Tensor:
|
||||
self.call_history.append("aggregated")
|
||||
return super().aggregated
|
||||
|
||||
@staticmethod
|
||||
def compute(self, data: Any, output: Any):
|
||||
self.call_history.append("compute")
|
||||
return super(DummyMetric, self).compute(self, data, output)
|
||||
|
||||
metric = DummyMetric()
|
||||
assert metric.call_history == ["init"]
|
||||
result = metric(torch.tensor([2.0]), torch.tensor([1.0]))
|
||||
assert torch.allclose(result, torch.tensor(1.0))
|
||||
assert metric.call_history == [
|
||||
"init",
|
||||
"input_convert",
|
||||
"forward",
|
||||
"output_convert",
|
||||
"ddp_reduce",
|
||||
"ddp_sync",
|
||||
"aggregate",
|
||||
"compute"
|
||||
]
|
||||
aggr = metric.aggregated
|
||||
assert metric.call_history == [
|
||||
"init",
|
||||
"input_convert",
|
||||
"forward",
|
||||
"output_convert",
|
||||
"ddp_reduce",
|
||||
"ddp_sync",
|
||||
"aggregate",
|
||||
"compute",
|
||||
"aggregated",
|
||||
"aggregate",
|
||||
"reset",
|
||||
"compute"
|
||||
]
|
||||
assert torch.allclose(aggr, result)
|
||||
_ = metric(torch.tensor(2.0), torch.tensor(1.0))
|
||||
assert metric.call_history == [
|
||||
"init",
|
||||
"input_convert",
|
||||
"forward",
|
||||
"output_convert",
|
||||
"ddp_reduce",
|
||||
"ddp_sync",
|
||||
"aggregate",
|
||||
"compute",
|
||||
"aggregated",
|
||||
"aggregate",
|
||||
"reset",
|
||||
"compute",
|
||||
"input_convert",
|
||||
"forward",
|
||||
"output_convert",
|
||||
"ddp_reduce",
|
||||
"ddp_sync",
|
||||
"aggregate",
|
||||
"compute"
|
||||
]
|
||||
|
||||
metric = DummyMetric()
|
||||
_ = metric(torch.tensor([2.0]), torch.tensor([1.0]))
|
||||
_ = metric(torch.tensor([3.0]), torch.tensor([0.0]))
|
||||
|
||||
aggregated = metric.aggregated
|
||||
|
||||
assert torch.allclose(aggregated, torch.tensor(4.0))
|
||||
|
||||
assert metric.call_history == [
|
||||
"init",
|
||||
"input_convert",
|
||||
"forward",
|
||||
"output_convert",
|
||||
"ddp_reduce",
|
||||
"ddp_sync",
|
||||
"aggregate",
|
||||
"compute",
|
||||
"input_convert",
|
||||
"forward",
|
||||
"output_convert",
|
||||
"ddp_reduce",
|
||||
"ddp_sync",
|
||||
"aggregate",
|
||||
"compute",
|
||||
"aggregated",
|
||||
"aggregate",
|
||||
"reset",
|
||||
"compute",
|
||||
]
|
|
@ -1,29 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.nlp import BLEUScore
|
||||
|
||||
# example taken from
|
||||
# https://www.nltk.org/api/nltk.translate.html?highlight=bleu%20score#nltk.translate.bleu_score.corpus_bleu
|
||||
HYP1 = "It is a guide to action which ensures that the military always obeys the commands of the party".split()
|
||||
HYP2 = "he read the book because he was interested in world history".split()
|
||||
|
||||
REF1A = "It is a guide to action that ensures that the military will forever heed Party commands".split()
|
||||
REF1B = "It is a guiding principle which makes the military forces always being under the command of the Party".split()
|
||||
REF1C = "It is the practical guide for the army always to heed the directions of the party".split()
|
||||
REF2A = "he was interested in world history because he read the book".split()
|
||||
|
||||
LIST_OF_REFERENCES = [[REF1A, REF1B, REF1C], [REF2A]]
|
||||
HYPOTHESES = [HYP1, HYP2]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["n_gram", "smooth"],
|
||||
[pytest.param(1, True), pytest.param(2, False), pytest.param(3, True), pytest.param(4, False),],
|
||||
)
|
||||
def test_bleu(smooth, n_gram):
|
||||
bleu = BLEUScore(n_gram=n_gram, smooth=smooth)
|
||||
assert bleu.name == "bleu"
|
||||
|
||||
pl_output = bleu(HYPOTHESES, LIST_OF_REFERENCES)
|
||||
assert isinstance(pl_output, torch.Tensor)
|
|
@ -1,69 +0,0 @@
|
|||
# NOTE: This file only tests if modules with arguments are running fine.
|
||||
# The actual metric implementation is tested in functional/test_regression.py
|
||||
# Especially reduction and reducing across processes won't be tested here!
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.metrics.regression import (
|
||||
MAE, MSE, RMSE, RMSLE, PSNR, SSIM
|
||||
)
|
||||
|
||||
|
||||
def test_mse():
|
||||
mse = MSE()
|
||||
assert mse.name == 'mse'
|
||||
|
||||
pred = torch.tensor([0., 1, 2, 3])
|
||||
target = torch.tensor([0., 1, 2, 2])
|
||||
score = mse(pred, target)
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
def test_rmse():
|
||||
rmse = RMSE()
|
||||
assert rmse.name == 'rmse'
|
||||
|
||||
pred = torch.tensor([0., 1, 2, 3])
|
||||
target = torch.tensor([0., 1, 2, 2])
|
||||
score = rmse(pred, target)
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
def test_mae():
|
||||
mae = MAE()
|
||||
assert mae.name == 'mae'
|
||||
|
||||
pred = torch.tensor([0., 1, 2, 3])
|
||||
target = torch.tensor([0., 1, 2, 2])
|
||||
score = mae(pred, target)
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
def test_rmsle():
|
||||
rmsle = RMSLE()
|
||||
assert rmsle.name == 'rmsle'
|
||||
|
||||
pred = torch.tensor([0., 1, 2, 3])
|
||||
target = torch.tensor([0., 1, 2, 2])
|
||||
score = rmsle(pred, target)
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
def test_psnr():
|
||||
psnr = PSNR()
|
||||
assert psnr.name == 'psnr'
|
||||
|
||||
pred = torch.tensor([0., 1, 2, 3])
|
||||
target = torch.tensor([0., 1, 2, 2])
|
||||
score = psnr(pred, target)
|
||||
assert isinstance(score, torch.Tensor)
|
||||
|
||||
|
||||
def test_ssim():
|
||||
ssim = SSIM()
|
||||
assert ssim.name == 'ssim'
|
||||
|
||||
pred = torch.rand([16, 1, 16, 16])
|
||||
target = pred * 0.75
|
||||
score = ssim(pred, target)
|
||||
assert isinstance(score, torch.Tensor)
|
|
@ -1,178 +0,0 @@
|
|||
import numbers
from functools import partial

import numpy as np
import pytest
import torch
from sklearn.metrics import (
    accuracy_score as sk_accuracy,
    precision_score as sk_precision,
    recall_score as sk_recall,
    f1_score as sk_f1_score,
    fbeta_score as sk_fbeta_score,
    confusion_matrix as sk_confusion_matrix,
    average_precision_score as sk_average_precision,
    auc as sk_auc,
    precision_recall_curve as sk_precision_recall_curve,
    roc_curve as sk_roc_curve,
    roc_auc_score as sk_roc_auc_score,
    balanced_accuracy_score as sk_balanced_accuracy_score,
    dcg_score as sk_dcg_score,
    mean_absolute_error as sk_mean_absolute_error,
    mean_squared_error as sk_mean_squared_error,
    mean_squared_log_error as sk_mean_squared_log_error,
    median_absolute_error as sk_median_absolute_error,
    r2_score as sk_r2_score,
    mean_poisson_deviance as sk_mean_poisson_deviance,
    mean_gamma_deviance as sk_mean_gamma_deviance,
    mean_tweedie_deviance as sk_mean_tweedie_deviance,
    explained_variance_score as sk_explained_variance_score,
    cohen_kappa_score as sk_cohen_kappa_score,
    hamming_loss as sk_hamming_loss,
    hinge_loss as sk_hinge_loss,
    jaccard_score as sk_jaccard_score
)

from pytorch_lightning.metrics.converters import convert_to_numpy
from pytorch_lightning.metrics.sklearns import (
    Accuracy,
    AUC,
    AveragePrecision,
    BalancedAccuracy,
    ConfusionMatrix,
    CohenKappaScore,
    DCG,
    F1,
    FBeta,
    Hamming,
    Hinge,
    Jaccard,
    Precision,
    Recall,
    PrecisionRecallCurve,
    ROC,
    AUROC,
    MeanAbsoluteError,
    MeanSquaredError,
    MeanSquaredLogError,
    MedianAbsoluteError,
    R2Score,
    MeanPoissonDeviance,
    MeanGammaDeviance,
    MeanTweedieDeviance,
    ExplainedVariance,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection


def _xy_only(func):
    def new_func(*args, **kwargs):
        return np.array(func(*args, **kwargs)[:2])

    return new_func
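The `_xy_only` wrapper exists because sklearn's curve functions return a third output (thresholds) whose length differs from the curve points, so only the first two outputs can be stacked into one array for comparison. A short standalone illustration (not part of the commit) of that behaviour:

# Illustration only: precision_recall_curve returns (precision, recall, thresholds);
# thresholds is one element shorter, so _xy_only keeps just the first two outputs.
import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1])
probas_pred = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, probas_pred)
xy = np.array(precision_recall_curve(y_true, probas_pred)[:2])

assert len(thresholds) == len(precision) - 1
assert xy.shape == (2, len(precision))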
@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [
    pytest.param(Accuracy(), sk_accuracy,
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))},
                 id='Accuracy'),
    pytest.param(AUC(), sk_auc,
                 {'x': torch.arange(10, dtype=torch.float) / 10,
                  'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.3, 0.5, 0.6, 0.7])},
                 id='AUC'),
    pytest.param(AveragePrecision(), sk_average_precision,
                 {'y_score': torch.randint(2, size=(128,)),
                  'y_true': torch.randint(2, size=(128,))},
                 id='AveragePrecision'),
    pytest.param(ConfusionMatrix(), sk_confusion_matrix,
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))},
                 id='ConfusionMatrix'),
    pytest.param(F1(average='macro'), partial(sk_f1_score, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))},
                 id='F1'),
    pytest.param(FBeta(beta=0.5, average='macro'), partial(sk_fbeta_score, beta=0.5, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))},
                 id='FBeta'),
    pytest.param(Precision(average='macro'), partial(sk_precision, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))},
                 id='Precision'),
    pytest.param(Recall(average='macro'), partial(sk_recall, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)),
                  'y_true': torch.randint(10, size=(128,))},
                 id='Recall'),
    pytest.param(PrecisionRecallCurve(), _xy_only(sk_precision_recall_curve),
                 {'probas_pred': torch.rand(size=(128,)),
                  'y_true': torch.randint(2, size=(128,))},
                 id='PrecisionRecallCurve'),
    pytest.param(ROC(), _xy_only(sk_roc_curve),
                 {'y_score': torch.rand(size=(128,)),
                  'y_true': torch.randint(2, size=(128,))},
                 id='ROC'),
    pytest.param(AUROC(), sk_roc_auc_score,
                 {'y_score': torch.rand(size=(128,)),
                  'y_true': torch.randint(2, size=(128,))},
                 id='AUROC'),
    pytest.param(BalancedAccuracy(), sk_balanced_accuracy_score,
                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
                 id='BalancedAccuracy'),
    pytest.param(DCG(), sk_dcg_score,
                 {'y_score': torch.rand(size=(128, 3)), 'y_true': torch.randint(3, size=(128, 3))},
                 id='DCG'),
    pytest.param(ExplainedVariance(), sk_explained_variance_score,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='ExplainedVariance'),
    pytest.param(MeanAbsoluteError(), sk_mean_absolute_error,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='MeanAbsoluteError'),
    pytest.param(MeanSquaredError(), sk_mean_squared_error,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='MeanSquaredError'),
    pytest.param(MeanSquaredLogError(), sk_mean_squared_log_error,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='MeanSquaredLogError'),
    pytest.param(MedianAbsoluteError(), sk_median_absolute_error,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='MedianAbsoluteError'),
    pytest.param(R2Score(), sk_r2_score,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='R2Score'),
    pytest.param(MeanPoissonDeviance(), sk_mean_poisson_deviance,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='MeanPoissonDeviance'),
    pytest.param(MeanGammaDeviance(), sk_mean_gamma_deviance,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='MeanGammaDeviance'),
    pytest.param(MeanTweedieDeviance(), sk_mean_tweedie_deviance,
                 {'y_pred': torch.rand(size=(128,)), 'y_true': torch.rand(size=(128,))},
                 id='MeanTweedieDeviance'),
    pytest.param(CohenKappaScore(), sk_cohen_kappa_score,
                 {'y1': torch.randint(3, size=(128,)), 'y2': torch.randint(3, size=(128,))},
                 id='CohenKappaScore'),
    pytest.param(Hamming(), sk_hamming_loss,
                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
                 id='Hamming'),
    pytest.param(Hinge(), sk_hinge_loss,
                 {'pred_decision': torch.randn(size=(128,)), 'y_true': torch.randint(2, size=(128,))},
                 id='Hinge'),
    pytest.param(Jaccard(average='macro'), partial(sk_jaccard_score, average='macro'),
                 {'y_pred': torch.randint(10, size=(128,)), 'y_true': torch.randint(10, size=(128,))},
                 id='Jaccard')
])
def test_sklearn_metric(metric_class, sklearn_func, inputs):
    numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)

    sklearn_result = sklearn_func(**numpy_inputs)
    lightning_result = metric_class(**inputs)

    sklearn_result = apply_to_collection(
        sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)

    lightning_result = np.array(apply_to_collection(
        lightning_result, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy))

    assert np.allclose(sklearn_result, lightning_result, atol=1e-5)
    assert isinstance(lightning_result, type(sklearn_result))
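Every parametrized case above follows the same flow: convert the tensor inputs to numpy, run the sklearn reference, run the Lightning wrapper on the original tensors, then compare with np.allclose. A minimal sketch of that flow for the Accuracy case, spelled out without the parametrization; it reuses the helpers imported in this (now removed) file, so it only runs against a release that still ships them:

# Sketch only (not part of the commit): one case of the comparison above.
import numbers

import numpy as np
import torch
from sklearn.metrics import accuracy_score as sk_accuracy

from pytorch_lightning.metrics.converters import convert_to_numpy
from pytorch_lightning.metrics.sklearns import Accuracy
from pytorch_lightning.utilities.apply_func import apply_to_collection

inputs = {'y_pred': torch.randint(10, size=(128,)),
          'y_true': torch.randint(10, size=(128,))}

# convert the tensor inputs for the sklearn reference call
numpy_inputs = apply_to_collection(inputs, (torch.Tensor, np.ndarray, numbers.Number), convert_to_numpy)

sk_result = sk_accuracy(**numpy_inputs)
pl_result = convert_to_numpy(Accuracy()(**inputs))

assert np.allclose(sk_result, pl_result, atol=1e-5)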
@@ -0,0 +1,61 @@
import torch
import numpy as np
import os
import sys
import pytest

NUM_PROCESSES = 2
NUM_BATCHES = 10
BATCH_SIZE = 16


def setup_ddp(rank, world_size):
    """Initialize a gloo process group on localhost for the ddp tests."""
    os.environ["MASTER_ADDR"] = 'localhost'
    os.environ['MASTER_PORT'] = '8088'
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_step, worldsize=1, metric_args={}):
    metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)

    # only use ddp if world size > 1
    if worldsize > 1:
        setup_ddp(rank, worldsize)

    # each rank processes every `worldsize`-th batch, starting at its own rank
    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = metric(preds[i], target[i])

        if metric.ddp_sync_on_step:
            if rank == 0:
                # with per-step syncing, rank 0 should see the result over all ranks' batches
                ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)])
                ddp_target = torch.stack([target[i + r] for r in range(worldsize)])
                sk_batch_result = sk_metric(ddp_preds, ddp_target)
                assert np.allclose(batch_result.numpy(), sk_batch_result)
        else:
            sk_batch_result = sk_metric(preds[i], target[i])
            assert np.allclose(batch_result.numpy(), sk_batch_result)

    # check on all batches on all ranks
    result = metric.compute()
    assert isinstance(result, torch.Tensor)

    total_preds = torch.stack([preds[i] for i in range(NUM_BATCHES)])
    total_target = torch.stack([target[i] for i in range(NUM_BATCHES)])
    sk_result = sk_metric(total_preds, total_target)

    assert np.allclose(result.numpy(), sk_result)
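The subtle part of _compute_batch is the ddp_sync_on_step branch: when the metric syncs every step, rank 0's per-step result must match the reference metric over the batches seen by all ranks in that step, not just its own batch. A small standalone illustration of the stacking (not part of the commit; the mean-based "metric" is only a stand-in for sk_metric):

# Illustration only: what the stacking in the ddp_sync_on_step branch compares.
import torch

def mae(preds, target):
    # stand-in reference metric: mean absolute error over all elements
    return (preds - target).abs().mean()

worldsize = 2
preds = [torch.rand(16) for _ in range(worldsize)]    # one batch per rank in a given step
target = [torch.rand(16) for _ in range(worldsize)]

per_rank = [mae(p, t) for p, t in zip(preds, target)]       # what each rank alone would see
synced = mae(torch.stack(preds), torch.stack(target))       # what rank 0 checks after the sync

# with equal batch sizes, the synced value is the average over ranks, not rank 0's own value
assert torch.allclose(synced, torch.stack(per_rank).mean())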
def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}):
    if ddp:
        if sys.platform == "win32":
            pytest.skip("DDP not supported on windows")

        torch.multiprocessing.spawn(
            _compute_batch, args=(preds, target, metric_class, sk_metric, ddp_sync_on_step, NUM_PROCESSES, metric_args),
            nprocs=NUM_PROCESSES
        )
    else:
        # single-process run: rank 0, world size 1
        _compute_batch(0, preds, target, metric_class, sk_metric, ddp_sync_on_step, 1, metric_args)
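A sketch of how a concrete metric test might drive compute_batch (illustration only: the import path, the use of the new Accuracy metric, and hard-label predictions are assumptions, not something this helper file defines):

# Hypothetical test module built on the helper above; names marked below are assumptions.
import pytest
import torch
from sklearn.metrics import accuracy_score

from pytorch_lightning.metrics import Accuracy
from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE  # assumed location of the file above

torch.manual_seed(42)

# NUM_BATCHES batches of hard predictions/targets for a 10-class problem
preds = torch.randint(10, size=(NUM_BATCHES, BATCH_SIZE))
target = torch.randint(10, size=(NUM_BATCHES, BATCH_SIZE))


def _sk_accuracy(preds, target):
    # reference implementation on flattened numpy arrays
    return accuracy_score(target.view(-1).numpy(), preds.view(-1).numpy())


@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("ddp_sync_on_step", [False, True])
def test_accuracy(ddp, ddp_sync_on_step):
    compute_batch(preds, target, Accuracy, _sk_accuracy, ddp_sync_on_step, ddp=ddp)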