2020-06-16 11:42:56 +00:00
|
|
|
.. testsetup:: *
|
|
|
|
|
2020-06-17 21:44:11 +00:00
|
|
|
import torch
|
2020-06-16 11:42:56 +00:00
|
|
|
from torch.nn import Module
|
|
|
|
from pytorch_lightning.core.lightning import LightningModule
|
2020-10-06 21:03:24 +00:00
|
|
|
from pytorch_lightning.metrics import Metric
|
2020-06-16 11:42:56 +00:00
|
|
|
|
2020-08-13 22:56:51 +00:00
|
|
|
.. _metrics:
|
|
|
|
|
2020-06-16 11:42:56 +00:00
|
|
|
Metrics
|
|
|
|
=======
|
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in
|
|
|
|
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
|
|
|
|
common metric implementations.
|
2020-06-16 11:42:56 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
The metrics API provides ``update()``, ``compute()`` and ``reset()`` methods to the user. The base metric class inherits from
|
|
|
|
``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
|
|
|
|
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
|
|
|
|
provided input.
|
2020-06-16 11:42:56 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
|
|
|
|
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
|
|
|
|
logic present in ``.compute()`` is applied to state information from all processes.
|
2020-06-16 11:42:56 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
The example below shows how to use a metric in your ``LightningModule``:
|
2020-06-16 11:42:56 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
.. note::
|
2020-06-16 11:42:56 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end of each epoch.
|
|
|
|
This is shown in the example below. For the v1.0 release, we will integrate metrics
|
|
|
|
with logging and ``.compute()`` will be called automatically by PyTorch Lightning.
|
2020-06-17 14:53:48 +00:00
|
|
|
|
2020-06-17 21:44:11 +00:00
|
|
|
.. code-block:: python
|
2020-06-17 14:53:48 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
def __init__(self):
|
|
|
|
...
|
|
|
|
self.accuracy = pl.metrics.Accuracy()
|
|
|
|
|
|
|
|
def training_step(self, batch, batch_idx):
|
|
|
|
x, y = batch
logits = self(x)
|
|
|
|
...
|
|
|
|
# log step metric
|
|
|
|
self.log('train_acc_step', self.accuracy(logits, y))
|
|
|
|
...
|
|
|
|
|
|
|
|
def training_epoch_end(self, outs):
|
|
|
|
# log epoch metric
|
|
|
|
self.log('train_acc_epoch', self.accuracy.compute())
|
2020-06-17 14:53:48 +00:00
|
|
|
|
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
The metrics API is independent of PyTorch Lightning. Metrics can be used directly in plain PyTorch, as shown in this example:
|
2020-06-17 14:53:48 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
from pytorch_lightning import metrics
|
|
|
|
|
|
|
|
train_accuracy = metrics.Accuracy()
|
|
|
|
valid_accuracy = metrics.Accuracy(compute_on_step=False)
|
|
|
|
|
|
|
|
for epoch in range(epochs):
|
|
|
|
for x, y in train_data:
|
|
|
|
y_hat = model(x)
|
|
|
|
|
|
|
|
# training step accuracy
|
|
|
|
batch_acc = train_accuracy(y_hat, y)
|
|
|
|
|
|
|
|
for x, y in valid_data:
|
|
|
|
y_hat = model(x)
|
|
|
|
valid_accuracy(y_hat, y)
|
|
|
|
|
|
|
|
# total accuracy over all training batches
|
|
|
|
total_train_accuracy = train_accuracy.compute()
|
|
|
|
|
|
|
|
# total accuracy over all validation batches
|
|
|
|
total_valid_accuracy = valid_accuracy.compute()
|
2020-06-16 11:42:56 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
Implementing a Metric
|
2020-06-16 11:42:56 +00:00
|
|
|
---------------------
|
2020-06-18 13:06:31 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
- ``__init__()``: Each state variable should be registered using ``self.add_state(...)``.
|
|
|
|
- ``update()``: Any code needed to update the state given any inputs to the metric.
|
|
|
|
- ``compute()``: Computes a final value from the state of the metric.
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
All you need to do is call ``add_state()`` correctly to implement a custom metric with DDP.
|
|
|
|
``reset()`` is called on metric state variables added using ``add_state()``.
|
2020-07-22 13:58:24 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
To see how metric states are synchronized across distributed processes, refer to the ``add_state()`` docs
|
|
|
|
from the base ``Metric`` class.
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
Example implementation:
|
2020-06-17 11:34:39 +00:00
|
|
|
|
|
|
|
.. code-block:: python
|
2020-07-22 13:58:24 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
from pytorch_lightning.metrics import Metric
|
2020-07-22 13:58:24 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
class MyAccuracy(Metric):
|
|
|
|
def __init__(self, ddp_sync_on_step=False):
|
|
|
|
super().__init__(ddp_sync_on_step=ddp_sync_on_step)
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
|
|
|
|
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
|
|
|
preds, target = self._input_format(preds, target)
|
|
|
|
assert preds.shape == target.shape
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
self.correct += torch.sum(preds == target)
|
|
|
|
self.total += target.numel()
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
def compute(self):
|
|
|
|
return self.correct.float() / self.total
|
2020-08-05 09:32:53 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
Metric
|
|
|
|
^^^^^^
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
.. autoclass:: pytorch_lightning.metrics.Metric
|
2020-06-17 11:34:39 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
Classification Metrics
|
|
|
|
----------------------
|
2020-06-17 11:34:39 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
Accuracy
|
2020-06-17 11:34:39 +00:00
|
|
|
^^^^^^^^
|
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
|
2020-08-05 09:32:53 +00:00
|
|
|
:noindex:
|
2020-09-01 18:59:33 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
Regression Metrics
|
|
|
|
------------------
|
2020-09-01 18:59:33 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
MeanSquaredError
|
|
|
|
^^^^^^^^^^^^^^^^
|
2020-08-05 09:32:53 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
|
2020-08-05 09:32:53 +00:00
|
|
|
:noindex:
|
2020-09-01 18:59:33 +00:00
|
|
|
|
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
MeanAbsoluteError
|
|
|
|
^^^^^^^^^^^^^^^^^
|
2020-08-05 09:32:53 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
|
2020-08-05 09:32:53 +00:00
|
|
|
:noindex:
|
|
|
|
|
2020-09-01 18:59:33 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
MeanSquaredLogError
|
|
|
|
^^^^^^^^^^^^^^^^^^^
|
2020-08-05 09:32:53 +00:00
|
|
|
|
2020-10-06 21:03:24 +00:00
|
|
|
.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
|
2020-08-05 09:32:53 +00:00
|
|
|
:noindex:
|