.. testsetup:: *

    import torch
    from torch.nn import Module
    from pytorch_lightning.core.lightning import LightningModule
    from pytorch_lightning.metrics import Metric

.. _metrics:

#######
Metrics
#######

``pytorch_lightning.metrics`` is a Metrics API created for easy metric development and usage in
PyTorch and PyTorch Lightning. It is rigorously tested for all edge cases and includes a growing list of
common metric implementations.

The metrics API provides ``update()``, ``compute()``, ``reset()`` functions to the user. The metric base class inherits
``nn.Module`` which allows us to call ``metric(...)`` directly. The ``forward()`` method of the base ``Metric`` class
serves the dual purpose of calling ``update()`` on its input and simultaneously returning the value of the metric over the
provided input.
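
For illustration, a minimal sketch of how these calls fit together (using ``Accuracy``; the inputs are arbitrary):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    metric = Accuracy()

    for _ in range(2):  # e.g. loop over batches
        preds = torch.randint(0, 2, (10,))
        target = torch.randint(0, 2, (10,))
        batch_acc = metric(preds, target)  # forward(): calls update() and returns the batch value

    epoch_acc = metric.compute()  # value accumulated over all update() calls
    metric.reset()                # clear the internal state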

.. warning::
    From v1.2 onward ``compute()`` will no longer automatically call ``reset()``,
    and it is up to the user to reset metrics between epochs, except in the case where the
    metric is directly passed to ``LightningModule``'s ``self.log``.
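
If you log metric values manually, a minimal sketch of resetting between epochs (assuming a ``self.accuracy`` metric attribute as in the example below):

.. code-block:: python

    def training_epoch_end(self, outs):
        # log the accumulated epoch value, then reset the state manually (v1.2+)
        self.log('train_acc_epoch', self.accuracy.compute())
        self.accuracy.reset()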

These metrics work with DDP in PyTorch and PyTorch Lightning by default. When ``.compute()`` is called in
distributed mode, the internal state of each metric is synced and reduced across each process, so that the
logic present in ``.compute()`` is applied to state information from all processes.

The example below shows how to use a metric in your ``LightningModule``:

.. code-block:: python

    def __init__(self):
        ...
        self.accuracy = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.log('train_acc_step', self.accuracy(preds, y))
        ...

    def training_epoch_end(self, outs):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy.compute())

``Metric`` objects can also be logged directly, in which case Lightning will log
the metric based on the ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``.
If ``on_epoch`` is True, the logger automatically logs the end of epoch metric value by calling
``.compute()``.

.. note::
    The ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx``
    flags from ``self.log(...)`` don't affect the metric logging in any manner. The metric class
    contains its own distributed synchronization logic.

    This however is only true for metrics that inherit the base class ``Metric``,
    and thus the functional metric API provides no support for in-built distributed synchronization
    or reduction functions.

.. code-block:: python

    def __init__(self):
        ...
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        self.train_acc(preds, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)

.. note::
    If using metrics in data parallel mode (dp), the metric update/logging should be done
    in the ``<mode>_step_end`` method (where ``<mode>`` is either ``training``, ``validation``
    or ``test``). This is because the metric states would otherwise be destroyed after each
    forward pass, leading to wrong accumulation. In practice do the following:

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            data, target = batch
            preds = self(data)
            ...
            return {'loss': loss, 'preds': preds, 'target': target}

        def training_step_end(self, outputs):
            # update and log
            self.metric(outputs['preds'], outputs['target'])
            self.log('metric', self.metric)

This metrics API is independent of PyTorch Lightning. Metrics can be used directly in PyTorch, as shown in this example:

.. code-block:: python

    from pytorch_lightning import metrics

    train_accuracy = metrics.Accuracy()
    valid_accuracy = metrics.Accuracy(compute_on_step=False)

    for epoch in range(epochs):
        for x, y in train_data:
            y_hat = model(x)

            # training step accuracy
            batch_acc = train_accuracy(y_hat, y)

        for x, y in valid_data:
            y_hat = model(x)
            valid_accuracy(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

.. note::
    Metrics contain internal states that keep track of the data seen so far.
    Do not mix metric states across training, validation and testing.
    It is highly recommended to re-initialize the metric per mode as
    shown in the examples above. To easily initialize the same metric multiple
    times, the ``.clone()`` method can be used:

    .. testcode::

        from pytorch_lightning.metrics import Accuracy

        def __init__(self):
            ...
            metric = Accuracy()
            self.train_acc = metric.clone()
            self.val_acc = metric.clone()
            self.test_acc = metric.clone()

.. note::
    Metric states are **not** added to the model's ``state_dict`` by default.
    To change this, after initializing the metric, the method ``.persistent(mode)`` can
    be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
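
For example, a minimal sketch of enabling this behaviour:

.. code-block:: python

    from pytorch_lightning.metrics import Accuracy

    metric = Accuracy()
    metric.persistent(True)  # metric states will now be included in the state_dict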

*******************
Metrics and devices
*******************

Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave
similarly to buffers and parameters of modules. This means that metric states should
be moved to the same device as the input of the metric:

.. code-block:: python

    from pytorch_lightning.metrics import ConfusionMatrix

    target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
    preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))

    # Metric states are always initialized on cpu, and need to be moved to
    # the correct device
    confmat = ConfusionMatrix(num_classes=2).to(torch.device("cuda", 0))
    out = confmat(preds, target)
    print(out.device)  # cuda:0

However, when **properly defined** inside a :class:`~pytorch_lightning.core.lightning.LightningModule`,
Lightning will automatically move the metrics to the same device as the data. Being
**properly defined** means that the metric is correctly identified as a child module of the
model (check the ``.children()`` attribute of the model). Therefore, metrics cannot be placed
in native python ``list`` and ``dict``, as they will not be correctly identified
as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList`, and instead of
``dict`` use :class:`~torch.nn.ModuleDict`.

.. testcode::

    from torch import nn
    from pytorch_lightning.metrics import Accuracy

    class MyModule(LightningModule):

        def __init__(self):
            ...
            # valid ways metrics will be identified as child modules
            self.metric1 = Accuracy()
            self.metric2 = nn.ModuleList([Accuracy()])
            self.metric3 = nn.ModuleDict({'accuracy': Accuracy()})

        def training_step(self, batch, batch_idx):
            # all metrics will be on the same device as the input batch
            data, target = batch
            preds = self(data)
            ...
            val1 = self.metric1(preds, target)
            val2 = self.metric2[0](preds, target)
            val3 = self.metric3['accuracy'](preds, target)

*********************
Implementing a Metric
*********************

To implement your custom metric, subclass the base ``Metric`` class and implement the following methods:

- ``__init__()``: Each state variable should be registered using ``self.add_state(...)``.
- ``update()``: Any code needed to update the state given any inputs to the metric.
- ``compute()``: Computes a final value from the state of the metric.

All you need to do is call ``add_state`` correctly to implement a custom metric with DDP.

``reset()`` is called on metric state variables added using ``add_state()``.

To see how metric states are synchronized across distributed processes, refer to the ``add_state()`` docs
from the base ``Metric`` class.

Example implementation:

.. testcode::

    from pytorch_lightning.metrics import Metric

    class MyAccuracy(Metric):

        def __init__(self, dist_sync_on_step=False):
            super().__init__(dist_sync_on_step=dist_sync_on_step)

            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor):
            # _input_format is a user-defined helper that brings preds and
            # target into a common shape
            preds, target = self._input_format(preds, target)
            assert preds.shape == target.shape

            self.correct += torch.sum(preds == target)
            self.total += target.numel()

        def compute(self):
            return self.correct.float() / self.total

Metrics support backpropagation if all computations involved in the metric calculation
are differentiable. However, note that the cached state is detached from the computational
graph and cannot be backpropagated. Not doing this would mean storing the computational
graph for each update call, which can lead to out-of-memory errors.
In practice this means that:

.. code-block:: python

    metric = MyMetric()
    val = metric(pred, target)  # this value can be backpropagated
    val = metric.compute()      # this value cannot be backpropagated

Metric API
----------

.. autoclass:: pytorch_lightning.metrics.Metric
    :noindex:

Internal implementation details
-------------------------------

This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, Lightning wraps the user defined ``update()`` and ``compute()`` methods. We do this to automatically
synchronize and reduce metric states across multiple devices. More precisely, calling ``update()`` does the
following internally:

1. Clears the computed cache
2. Calls the user-defined ``update()``

Similarly, calling ``compute()`` does the following internally:

1. Syncs metric states between processes
2. Reduces the gathered metric states
3. Calls the user defined ``compute()`` method on the gathered metric states
4. Caches the computed result

From a user's standpoint this has one important side-effect: computed results are cached. This means that no
matter how many times ``compute`` is called in succession, it will continue to return the same result.
The cache is first emptied on the next call to ``update``.

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way (assuming the metric is initialized with ``compute_on_step=True``):

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state
3. Calls ``reset()`` to clear the global metric state
4. Calls ``update()`` to update the local metric state
5. Calls ``compute()`` to calculate the metric for the current batch
6. Restores the global state

This procedure has the consequence of calling the user defined ``update`` **twice** during a single
forward call (once to update the global statistics and once to get the batch statistics).
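
The caching behaviour can be observed directly; a minimal sketch using ``Accuracy`` (the inputs are arbitrary):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy

    metric = Accuracy()
    metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 0, 0]))

    first = metric.compute()   # syncs, reduces and computes the result, then caches it
    second = metric.compute()  # returns the cached result without recomputation
    assert torch.equal(first, second)

    metric.update(torch.tensor([1]), torch.tensor([1]))  # empties the computed cache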

*****************
Metric Arithmetic
*****************

Metrics support most Python built-in operators for arithmetic, logic and bitwise operations.
For example, a metric that should return the sum of two different metrics does not need to be
implemented from scratch. It can simply be written as:

.. code-block:: python

    first_metric = MyFirstMetric()
    second_metric = MySecondMetric()

    new_metric = first_metric + second_metric

``new_metric.update(*args, **kwargs)`` now calls update of ``first_metric`` and ``second_metric``. It forwards all positional arguments but
forwards only the keyword arguments that are available in the respective metric's update declaration.
Similarly ``new_metric.compute()`` now calls compute of ``first_metric`` and ``second_metric`` and adds the results up.

This pattern is implemented for the following operators (with ``a`` being metrics and ``b`` being metrics, tensors, integers or floats):

* Addition (``a + b``)
* Bitwise AND (``a & b``)
* Equality (``a == b``)
* Floor Division (``a // b``)
* Greater Equal (``a >= b``)
* Greater (``a > b``)
* Less Equal (``a <= b``)
* Less (``a < b``)
* Matrix Multiplication (``a @ b``)
* Modulo (``a % b``)
* Multiplication (``a * b``)
* Inequality (``a != b``)
* Bitwise OR (``a | b``)
* Power (``a ** b``)
* Subtraction (``a - b``)
* True Division (``a / b``)
* Bitwise XOR (``a ^ b``)
* Absolute Value (``abs(a)``)
* Inversion (``~a``)
* Negative Value (``neg(a)``)
* Positive Value (``pos(a)``)
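
As a minimal sketch of this in action, composing two of the built-in regression metrics (assuming both accept the same ``(preds, target)`` call signature):

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import MeanSquaredError, MeanAbsoluteError

    mse = MeanSquaredError()
    mae = MeanAbsoluteError()
    combined = mse + mae  # composed metric that sums the two results

    preds = torch.tensor([2.5, 0.0, 2.0])
    target = torch.tensor([3.0, -0.5, 2.0])

    combined.update(preds, target)  # updates both underlying metrics
    print(combined.compute())       # equals mse.compute() + mae.compute()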

****************
MetricCollection
****************

In many cases it is beneficial to evaluate the model output with multiple metrics.
In this case the ``MetricCollection`` class may come in handy. It accepts a sequence
of metrics and wraps these into a single callable metric class, with the same
interface as any other metric.

Example:

.. testcode::

    from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall

    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])

    metric_collection = MetricCollection([
        Accuracy(),
        Precision(num_classes=3, average='macro'),
        Recall(num_classes=3, average='macro')
    ])

    print(metric_collection(preds, target))

.. testoutput::
    :options: +NORMALIZE_WHITESPACE

    {'Accuracy': tensor(0.1250),
     'Precision': tensor(0.0667),
     'Recall': tensor(0.1111)}

Similarly it can also reduce the amount of code required to log multiple metrics
inside your LightningModule:

.. code-block:: python

    def __init__(self):
        ...
        metrics = pl.metrics.MetricCollection(...)
        self.train_metrics = metrics.clone()
        self.valid_metrics = metrics.clone()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.train_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.train_metrics, on_step=True, on_epoch=False, prefix='train')

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_metrics(logits, y)
        # use log_dict instead of log
        self.log_dict(self.valid_metrics, on_step=True, on_epoch=True, prefix='val')

.. note::
    ``MetricCollection`` by default assumes that all the metrics in the collection
    have the same call signature. If this is not the case, the inputs that should be
    given to different metrics can be given as keyword arguments to the collection.

.. autoclass:: pytorch_lightning.metrics.MetricCollection
    :noindex:

***************************
Class vs Functional Metrics
***************************

The functional metrics follow the simple paradigm: input in, output out. This means they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.

Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.
If you just want to compute the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.
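
As a quick illustration of the difference, a minimal sketch using ``Accuracy`` and its functional counterpart:

.. code-block:: python

    import torch
    from pytorch_lightning.metrics import Accuracy
    from pytorch_lightning.metrics.functional import accuracy

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    # functional interface: stateless, input in -> output out
    print(accuracy(preds, target))  # tensor(0.7500)

    # class interface: stateful, accumulates over multiple update calls
    metric = Accuracy()
    metric.update(preds, target)
    print(metric.compute())  # tensor(0.7500)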

**********************
Classification Metrics
**********************

Input types
-----------

For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for number of classes):

.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
    :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
    :widths: 20, 10, 10, 10, 10

    "Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
    "Multi-class", "(N,)", "``int``", "(N,)", "``int``"
    "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
    "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
    "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
    "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"

.. note::
    All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
    that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N,)``.

When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types:

.. testcode::

    # Binary inputs
    binary_preds = torch.tensor([0.6, 0.1, 0.9])
    binary_target = torch.tensor([1, 0, 1])

    # Multi-class inputs
    mc_preds = torch.tensor([0, 2, 1])
    mc_target = torch.tensor([0, 1, 2])

    # Multi-class inputs with probabilities
    mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
    mc_target_probs = torch.tensor([0, 1, 2])

    # Multi-label inputs
    ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
    ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
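
For completeness, a multi-dimensional multi-class example, following the shapes and dtypes from the table above, might look like:

.. testcode::

    # Multi-dimensional multi-class inputs (e.g. per-pixel predictions)
    mdmc_preds = torch.tensor([[0, 2], [1, 1]])
    mdmc_target = torch.tensor([[0, 1], [2, 1]])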

Using the is_multiclass parameter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
integer (binary) tensors. Or it could be the other way around, you want to treat
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.

For these cases, the metrics where this distinction would make a difference expose the
``is_multiclass`` argument. Let's see how this is used on the example of the
:class:`~pytorch_lightning.metrics.StatScores` metric.

First, let's consider the case with label predictions with 2 classes, which we want to
treat as binary.

.. testcode::

    from pytorch_lightning.metrics.functional import stat_scores

    # These inputs are supposed to be binary, but appear as multi-class
    preds = torch.tensor([0, 1, 0])
    target = torch.tensor([1, 1, 0])

As you can see below, by default the inputs are treated
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
which is the same as converting the predictions to float beforehand.

.. doctest::

    >>> stat_scores(preds, target, reduce='macro', num_classes=2)
    tensor([[1, 1, 1, 0, 1],
            [1, 0, 1, 1, 2]])
    >>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
    tensor([[1, 0, 1, 1, 2]])
    >>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
    tensor([[1, 0, 1, 1, 2]])

Next, consider the opposite example: inputs are binary (as predictions are probabilities),
but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.

.. testcode::

    preds = torch.tensor([0.2, 0.7, 0.3])
    target = torch.tensor([1, 1, 0])

In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class.

.. doctest::

    >>> stat_scores(preds, target, reduce='macro', num_classes=1)
    tensor([[1, 0, 1, 1, 2]])
    >>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
    tensor([[1, 1, 1, 0, 1],
            [1, 0, 1, 1, 2]])

Class Metrics (Classification)
------------------------------

Accuracy
~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.Accuracy
    :noindex:

AveragePrecision
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.AveragePrecision
    :noindex:

AUC
~~~

.. autoclass:: pytorch_lightning.metrics.AUC
    :noindex:

AUROC
~~~~~

.. autoclass:: pytorch_lightning.metrics.AUROC
    :noindex:

ConfusionMatrix
~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.ConfusionMatrix
    :noindex:

F1
~~

.. autoclass:: pytorch_lightning.metrics.F1
    :noindex:

FBeta
~~~~~

.. autoclass:: pytorch_lightning.metrics.FBeta
    :noindex:

IoU
~~~

.. autoclass:: pytorch_lightning.metrics.IoU
    :noindex:

Hamming Distance
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.HammingDistance
    :noindex:

Precision
~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.Precision
    :noindex:

PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.PrecisionRecallCurve
    :noindex:

Recall
~~~~~~

.. autoclass:: pytorch_lightning.metrics.Recall
    :noindex:

ROC
~~~

.. autoclass:: pytorch_lightning.metrics.ROC
    :noindex:

StatScores
~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.StatScores
    :noindex:

Functional Metrics (Classification)
-----------------------------------

accuracy [func]
~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.accuracy
    :noindex:

auc [func]
~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.auc
    :noindex:

auroc [func]
~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.auroc
    :noindex:

average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.average_precision
    :noindex:

confusion_matrix [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
    :noindex:

dice_score [func]
~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.dice_score
    :noindex:

f1 [func]
~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.f1
    :noindex:

fbeta [func]
~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.fbeta
    :noindex:

hamming_distance [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance
    :noindex:

iou [func]
~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.iou
    :noindex:

roc [func]
~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.roc
    :noindex:

precision [func]
~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.precision
    :noindex:

precision_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
    :noindex:

precision_recall_curve [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.precision_recall_curve
    :noindex:

recall [func]
~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.recall
    :noindex:

select_topk [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.utils.select_topk
    :noindex:

stat_scores [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
    :noindex:

stat_scores_multiple_classes [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.stat_scores_multiple_classes
    :noindex:

to_categorical [func]
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.utils.to_categorical
    :noindex:

to_onehot [func]
~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.utils.to_onehot
    :noindex:

******************
Regression Metrics
******************

Class Metrics (Regression)
--------------------------

ExplainedVariance
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.ExplainedVariance
    :noindex:

MeanAbsoluteError
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.MeanAbsoluteError
    :noindex:

MeanSquaredError
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.MeanSquaredError
    :noindex:

MeanSquaredLogError
~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.MeanSquaredLogError
    :noindex:

PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.PSNR
    :noindex:

SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.SSIM
    :noindex:

R2Score
~~~~~~~

.. autoclass:: pytorch_lightning.metrics.R2Score
    :noindex:

Functional Metrics (Regression)
-------------------------------

explained_variance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.explained_variance
    :noindex:

image_gradients [func]
~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.image_gradients
    :noindex:

mean_absolute_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.mean_absolute_error
    :noindex:

mean_squared_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_error
    :noindex:

mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
    :noindex:

psnr [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.psnr
    :noindex:

ssim [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.ssim
    :noindex:

r2score [func]
~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.r2score
    :noindex:

***
NLP
***

bleu_score [func]
-----------------

.. autofunction:: pytorch_lightning.metrics.functional.bleu_score
    :noindex:

*****************************
Information Retrieval Metrics
*****************************

Class Metrics (IR)
------------------

Mean Average Precision
~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.retrieval.RetrievalMAP
    :noindex:

Functional Metrics (IR)
-----------------------

retrieval_average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.ir_average_precision.retrieval_average_precision
    :noindex:

********
Pairwise
********

embedding_similarity [func]
---------------------------

.. autofunction:: pytorch_lightning.metrics.functional.embedding_similarity
    :noindex: