lightning/tests/metrics/regression/test_explained_variance.py

from collections import namedtuple
from functools import partial

import pytest
import torch
from sklearn.metrics import explained_variance_score

from pytorch_lightning.metrics.functional import explained_variance
from pytorch_lightning.metrics.regression import ExplainedVariance
from tests.metrics.utils import BATCH_SIZE, NUM_BATCHES, MetricTester

torch.manual_seed(42)
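
# Shared random fixtures (reproducible via the fixed seed above): single-target
# inputs have shape [NUM_BATCHES, BATCH_SIZE]; multi-target inputs add a
# trailing dimension of num_targets independent regression targets.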
num_targets = 5

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_multi_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)
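

# Reference implementations: flatten the batched tensors into the
# (n_samples,) / (n_samples, n_targets) layout sklearn expects and delegate
# to sklearn.metrics.explained_variance_score as the ground truth.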
def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()
    return sk_fn(sk_target, sk_preds)


def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):
    sk_preds = preds.view(-1, num_targets).numpy()
    sk_target = target.view(-1, num_targets).numpy()
    return sk_fn(sk_target, sk_preds)
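

# For reference, sklearn's explained_variance_score is 1 - Var(y - y_pred) / Var(y),
# computed per target column; the `multioutput` argument controls aggregation:
# 'raw_values' returns one score per target, 'uniform_average' averages them,
# and 'variance_weighted' weights each score by the variance of its target.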
@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted'])
@pytest.mark.parametrize(
    "preds, target, sk_metric",
    [
        (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
        (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
    ],
)
class TestExplainedVariance(MetricTester):
    @pytest.mark.parametrize("ddp", [True, False])
    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
    def test_explained_variance(self, multioutput, preds, target, sk_metric, ddp, dist_sync_on_step):
        # Class-based metric: check the ExplainedVariance module against the
        # sklearn reference, with and without DDP and per-step state syncing.
        self.run_class_metric_test(
            ddp,
            preds,
            target,
            ExplainedVariance,
            partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)),
            dist_sync_on_step,
            metric_args=dict(multioutput=multioutput),
        )

    def test_explained_variance_functional(self, multioutput, preds, target, sk_metric):
        # Functional metric: check the explained_variance function against the
        # same sklearn reference.
        self.run_functional_metric_test(
            preds,
            target,
            explained_variance,
            partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)),
            metric_args=dict(multioutput=multioutput),
        )


def test_error_on_different_shape(metric_class=ExplainedVariance):
    metric = metric_class()
    with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
        metric(torch.randn(100,), torch.randn(50,))
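

# A minimal usage sketch, not collected by pytest: the module metric
# accumulates state over update() calls and compute() returns the pooled
# score, which is what run_class_metric_test above verifies against sklearn.
# The shapes and batch count below are illustrative assumptions.
def _usage_sketch():
    metric = ExplainedVariance(multioutput='uniform_average')
    for _ in range(4):
        preds, target = torch.rand(BATCH_SIZE, num_targets), torch.rand(BATCH_SIZE, num_targets)
        metric.update(preds, target)
    return metric.compute()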