
TorchMetrics
============
`TorchMetrics <https://torchmetrics.readthedocs.io>`_ is a collection of machine learning metrics for distributed,
scalable PyTorch models, together with an easy-to-use API for creating custom metrics. It provides more than 60
PyTorch metric implementations and is rigorously tested against edge cases.

.. code-block:: bash

    pip install torchmetrics

TorchMetrics offers the following benefits:

- A standardized interface to increase reproducibility
- Reduced boilerplate
- Compatibility with distributed training
- Rigorous testing
- Automatic accumulation over batches
- Automatic synchronization across multiple devices

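
The custom-metric API is built around subclassing ``torchmetrics.Metric``: states registered with ``add_state`` are accumulated in ``update``, turned into a result in ``compute``, and combined across processes according to ``dist_reduce_fx``. This is what provides the accumulation and synchronization listed above. The sketch below assumes a recent TorchMetrics release; ``MyAccuracy`` and its ``correct``/``total`` states are hypothetical names chosen for illustration, not a built-in metric.

.. code-block:: python

    import torch
    from torchmetrics import Metric


    class MyAccuracy(Metric):
        """Hypothetical example metric: fraction of correct argmax predictions."""

        # compute() only needs the accumulated states, not a full recomputation per batch
        full_state_update = False

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            # states persist across update() calls; in distributed runs they are
            # summed across processes because of dist_reduce_fx="sum"
            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
            # preds: (N, C) class scores, target: (N,) class indices
            predicted = preds.argmax(dim=-1)
            self.correct += (predicted == target).sum()
            self.total += target.numel()

        def compute(self) -> torch.Tensor:
            # final value derived from the accumulated state
            return self.correct.float() / self.total


    metric = MyAccuracy()
    metric.update(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 0]))  # 1 of 2 correct
    metric.update(torch.tensor([[0.3, 0.7]]), torch.tensor([1]))  # 1 of 1 correct
    print(metric.compute())  # tensor(0.6667)

Because both states are plain sums, the result over all batches (2 of 3 correct) does not depend on how the samples were split into batches or devices.
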

-----------------

Example 1: Functional Metrics
-----------------------------

Below is a simple example of computing accuracy with the functional interface:
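
.. code-block:: python

    import torch
    import torchmetrics

    # simulate a classification problem: 10 samples, 5 classes
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))

    # recent TorchMetrics releases require the task type (and the number of classes)
    acc = torchmetrics.functional.accuracy(preds, target, task="multiclass", num_classes=5)
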
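
For multiclass inputs, micro-averaged accuracy is the fraction of samples whose highest-scoring class matches the target. The sketch below checks this against a hand-written computation. It assumes TorchMetrics 0.11 or later (where ``task`` is required) and passes ``average="micro"`` explicitly instead of relying on the default; the fixed seed only makes the run reproducible.

.. code-block:: python

    import torch
    import torchmetrics

    torch.manual_seed(0)  # fixed seed for a reproducible run
    preds = torch.randn(10, 5).softmax(dim=-1)
    target = torch.randint(5, (10,))

    # library result, with micro averaging spelled out explicitly
    acc = torchmetrics.functional.accuracy(
        preds, target, task="multiclass", num_classes=5, average="micro"
    )

    # the same quantity by hand: share of samples whose argmax class is correct
    manual = (preds.argmax(dim=-1) == target).float().mean()

    print(acc, manual)

Note that ``preds`` does not have to be normalized for this check: taking the argmax of raw scores or of softmax probabilities selects the same class.
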

------------

Example 2: Module Metrics
-------------------------

The example below shows how to use the class-based interface:

.. code-block:: python

    import torch
    import torchmetrics

    # initialize the metric (recent TorchMetrics releases require the task type)
    metric = torchmetrics.Accuracy(task="multiclass", num_classes=5)

    n_batches = 10
    for i in range(n_batches):
        # simulate a classification problem
        preds = torch.randn(10, 5).softmax(dim=-1)
        target = torch.randint(5, (10,))

        # metric on the current batch (the batch is also added to the internal state)
        acc = metric(preds, target)
        print(f"Accuracy on batch {i}: {acc}")

    # metric on all batches, computed from the accumulated state
    acc = metric.compute()
    print(f"Accuracy on all data: {acc}")

    # reset the internal state so the metric is ready for new data
    metric.reset()

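
Accumulating over batches matters whenever batch sizes differ: the mean of per-batch accuracies is not the same as the accuracy over all samples. The plain-Python sketch below uses hypothetical data and no TorchMetrics at all; its accumulating variant illustrates the idea behind ``metric.compute()``, which derives the result from state summed over all calls rather than from averaged batch values.

.. code-block:: python

    # Each hypothetical batch is a list of (predicted_class, true_class) pairs;
    # the batches deliberately have different sizes.
    batches = [
        [(0, 0), (1, 1), (2, 2), (3, 0)],  # 3 of 4 correct
        [(4, 4)],                          # 1 of 1 correct
    ]

    # Naive approach: average the per-batch accuracies.
    batch_accs = [sum(p == t for p, t in batch) / len(batch) for batch in batches]
    mean_of_batches = sum(batch_accs) / len(batch_accs)

    # Accumulating approach: keep running counts and divide once at the end.
    correct = 0
    total = 0
    for batch in batches:
        correct += sum(p == t for p, t in batch)
        total += len(batch)
    accumulated = correct / total

    print(mean_of_batches)  # 0.875
    print(accumulated)      # 0.8

The naive average gives the single-sample batch the same weight as the four-sample batch, which is why it overstates the accuracy here.
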

------------

Example 3: TorchMetrics with Lightning
--------------------------------------

The example below shows how to use a metric in your :doc:`LightningModule <../common/lightning_module>`:
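
.. code-block:: python

    import torchmetrics
    from lightning.pytorch import LightningModule


    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            ...
            # assumes a 5-class problem, as in the examples above
            self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=5)

        def training_step(self, batch, batch_idx):
            x, y = batch
            preds = self(x)
            ...
            # update the metric on the current batch, then log the metric object;
            # on_epoch=True also logs the value accumulated over the epoch
            self.accuracy(preds, y)
            self.log("train_acc_step", self.accuracy, on_epoch=True)
            ...
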
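
A self-contained variant of the module above could look like the sketch below. It assumes Lightning 2.x (imported as ``lightning.pytorch``), a hypothetical linear classifier on 20-dimensional inputs with 5 classes, cross-entropy loss, and plain SGD; none of these choices come from the example above. It uses one metric instance per stage, so training and validation batches are never accumulated into the same state.

.. code-block:: python

    import torch
    import torch.nn.functional as F
    import torchmetrics
    from lightning.pytorch import LightningModule
    from torch import nn


    class LitClassifier(LightningModule):
        """Hypothetical 5-class linear classifier on 20-dimensional inputs."""

        def __init__(self, in_features: int = 20, num_classes: int = 5):
            super().__init__()
            self.model = nn.Linear(in_features, num_classes)
            # separate metric instances so training and validation states never mix
            self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
            self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

        def forward(self, x):
            return self.model(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.cross_entropy(logits, y)
            # forward call: returns the batch value and updates the epoch state
            self.train_acc(logits, y)
            self.log("train_acc", self.train_acc, on_step=True, on_epoch=True)
            return loss

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            # only accumulate here; the epoch-level value is logged from the state
            self.val_acc.update(logits, y)
            self.log("val_acc", self.val_acc, on_step=False, on_epoch=True)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

Because ``Metric`` is an ``nn.Module``, assigning it as an attribute registers it as a child module, so its states move to the same device as the model.
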