# lightning/tests/metrics/classification/test_accuracy.py
#
# 114 lines
# 3.7 KiB
# Python

import numpy as np
import pytest
import torch
from sklearn.metrics import accuracy_score
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from tests.metrics.classification.inputs import (
_binary_inputs,
_binary_prob_inputs,
_multiclass_inputs,
_multiclass_prob_inputs,
_multidim_multiclass_inputs,
_multidim_multiclass_prob_inputs,
_multilabel_inputs,
_multilabel_prob_inputs,
)
from tests.metrics.utils import THRESHOLD, MetricTester
# Seed torch's global RNG with a fixed value when this module is imported.
# NOTE(review): presumably for reproducible random tensors in these tests;
# the input fixtures are imported above this line, so confirm whether they
# rely on this seed or on their own.
torch.manual_seed(42)
def _sk_accuracy_binary_prob(preds, target):
    """Reference accuracy for probability-style binary predictions.

    Both tensors are flattened to 1-D. ``preds`` is binarised by comparing
    against ``THRESHOLD`` (values ``>= THRESHOLD`` become 1, the rest 0) and
    the result is scored against ``target`` with scikit-learn's
    ``accuracy_score``.
    """
    flat_probs = preds.view(-1).numpy()
    hard_labels = (flat_probs >= THRESHOLD).astype(np.uint8)
    flat_target = target.view(-1).numpy()
    return accuracy_score(flat_target, hard_labels)
def _sk_accuracy_binary(preds, target):
    """Reference accuracy for binary predictions that are already labels.

    Flattens ``preds`` and ``target`` to 1-D NumPy arrays and returns
    scikit-learn's ``accuracy_score`` for them; no thresholding is applied.
    """
    flat_preds, flat_target = (t.view(-1).numpy() for t in (preds, target))
    return accuracy_score(flat_target, flat_preds)
def _sk_accuracy_multilabel_prob(preds, target):
    """Reference accuracy for probability-style multilabel predictions.

    Every element is treated independently: both tensors are flattened,
    ``preds`` is binarised with ``>= THRESHOLD``, and the element-wise
    agreement with ``target`` is computed by ``accuracy_score``.
    """
    above_threshold = preds.view(-1).numpy() >= THRESHOLD
    hard_labels = above_threshold.astype(np.uint8)
    flat_target = target.view(-1).numpy()
    return accuracy_score(flat_target, hard_labels)
def _sk_accuracy_multilabel(preds, target):
    """Reference accuracy for multilabel predictions that are already labels.

    Both tensors are flattened to 1-D, so each element counts as one sample
    when ``accuracy_score`` compares them.
    """
    flat_target = target.view(-1).numpy()
    flat_preds = preds.view(-1).numpy()
    return accuracy_score(flat_target, flat_preds)
def _sk_accuracy_multiclass_prob(preds, target):
    """Reference accuracy for multiclass predictions given as per-class scores.

    The predicted label is the argmax over the last dimension of ``preds``
    (assumed to be the class axis -- confirm against the input fixtures).
    Labels and ``target`` are then flattened and scored with
    ``accuracy_score``.
    """
    # dim=-1 addresses the same axis as len(preds.shape) - 1.
    predicted_labels = preds.argmax(dim=-1)
    flat_preds = predicted_labels.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return accuracy_score(flat_target, flat_preds)
def _sk_accuracy_multiclass(preds, target):
    """Reference accuracy for multiclass predictions that are already labels.

    Flattens both tensors to 1-D NumPy arrays and delegates to scikit-learn's
    ``accuracy_score``.
    """
    as_flat_array = lambda tensor: tensor.view(-1).numpy()  # noqa: E731
    return accuracy_score(as_flat_array(target), as_flat_array(preds))
def _sk_accuracy_multidim_multiclass_prob(preds, target):
    """Reference accuracy for multi-dimensional multiclass score predictions.

    The predicted label is the argmax over the second-to-last dimension of
    ``preds`` (presumably the class axis, followed by one extra dimension --
    confirm against the input fixtures). The labels and ``target`` are then
    flattened and compared with ``accuracy_score``.
    """
    class_axis = preds.dim() - 2
    predicted_labels = preds.argmax(dim=class_axis)
    flat_preds = predicted_labels.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    return accuracy_score(flat_target, flat_preds)
def _sk_accuracy_multidim_multiclass(preds, target):
    """Reference accuracy for multi-dimensional multiclass label predictions.

    All dimensions are collapsed into one, so every element is scored as an
    individual sample by ``accuracy_score``.
    """
    flat_preds = preds.view(-1).numpy()
    flat_target = target.view(-1).numpy()
    reference_score = accuracy_score(flat_target, flat_preds)
    return reference_score
def test_accuracy_invalid_shape():
    """``Accuracy.update`` must raise ``ValueError`` for mismatched shapes.

    The metric is fed ``preds`` of shape ``(1,)`` and ``target`` of shape
    ``(1, 2, 3)``.
    """
    # Construct the metric outside the ``raises`` block so that only the
    # ``update`` call can satisfy the expectation; a ValueError raised by the
    # constructor would otherwise let the test pass for the wrong reason.
    acc = Accuracy()
    with pytest.raises(ValueError):
        acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3))
@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
@pytest.mark.parametrize(
    "preds, target, sk_metric",
    [
        # Each input fixture is paired with the scikit-learn based reference
        # implementation that matches its input format.
        (fixture.preds, fixture.target, reference)
        for fixture, reference in (
            (_binary_prob_inputs, _sk_accuracy_binary_prob),
            (_binary_inputs, _sk_accuracy_binary),
            (_multilabel_prob_inputs, _sk_accuracy_multilabel_prob),
            (_multilabel_inputs, _sk_accuracy_multilabel),
            (_multiclass_prob_inputs, _sk_accuracy_multiclass_prob),
            (_multiclass_inputs, _sk_accuracy_multiclass),
            (_multidim_multiclass_prob_inputs, _sk_accuracy_multidim_multiclass_prob),
            (_multidim_multiclass_inputs, _sk_accuracy_multidim_multiclass),
        )
    ],
)
class TestAccuracy(MetricTester):
    """Checks ``Accuracy`` against scikit-learn references.

    The stacked ``parametrize`` marks run the test for every combination of
    ``ddp``, ``dist_sync_on_step`` and (inputs, reference function) pair.
    """

    def test_accuracy(self, ddp, dist_sync_on_step, preds, target, sk_metric):
        """Run the class-metric test for one parametrized combination.

        ``Accuracy`` is configured with ``threshold=THRESHOLD`` and compared
        to ``sk_metric`` through ``run_class_metric_test``.
        """
        metric_kwargs = {"threshold": THRESHOLD}
        self.run_class_metric_test(
            metric_class=Accuracy,
            metric_args=metric_kwargs,
            sk_metric=sk_metric,
            preds=preds,
            target=target,
            ddp=ddp,
            dist_sync_on_step=dist_sync_on_step,
        )