Implement Explained Variance Metric + metric fix (#4013)
* metric fix, explained variance * one more test * pep8 * remove comment * fix add_state condition Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
This commit is contained in:
parent
7db26a913b
commit
b961e12f50
|
@ -181,6 +181,13 @@ MeanSquaredLogError
|
|||
:noindex:
|
||||
|
||||
|
||||
ExplainedVariance
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
|
||||
:noindex:
|
||||
|
||||
|
||||
Functional Metrics
|
||||
==================
|
||||
|
||||
|
|
|
@ -1,4 +1,9 @@
|
|||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
from pytorch_lightning.metrics.classification.accuracy import Accuracy
|
||||
from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError
|
||||
from pytorch_lightning.metrics.regression import (
|
||||
MeanSquaredError,
|
||||
MeanAbsoluteError,
|
||||
MeanSquaredLogError,
|
||||
ExplainedVariance,
|
||||
)
|
||||
|
|
|
@ -96,7 +96,11 @@ class Metric(nn.Module, ABC):
|
|||
the format discussed in the above note.
|
||||
|
||||
"""
|
||||
if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0):
|
||||
if (
|
||||
not isinstance(default, torch.Tensor)
|
||||
and not isinstance(default, list) # noqa: W503
|
||||
or (isinstance(default, list) and len(default) != 0) # noqa: W503
|
||||
):
|
||||
raise ValueError(
|
||||
"state variable must be a tensor or any empty list (where you can append tensors)"
|
||||
)
|
||||
|
@ -163,7 +167,7 @@ class Metric(nn.Module, ABC):
|
|||
elif isinstance(output_dict[attr][0], list):
|
||||
output_dict[attr] = _flatten(output_dict[attr])
|
||||
|
||||
assert isinstance(reduction_fn, (Callable, None))
|
||||
assert isinstance(reduction_fn, (Callable)) or reduction_fn is None
|
||||
reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
|
||||
setattr(self, attr, reduced)
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError
|
||||
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError
|
||||
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError
|
||||
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
import torch
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from pytorch_lightning.metrics.metric import Metric
|
||||
|
||||
|
||||
class ExplainedVariance(Metric):
|
||||
"""
|
||||
Computes explained variance.
|
||||
|
||||
Example:
|
||||
|
||||
>>> from pytorch_lightning.metrics import ExplainedVariance
|
||||
>>> target = torch.tensor([3, -0.5, 2, 7])
|
||||
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
|
||||
>>> explained_variance = ExplainedVariance()
|
||||
>>> explained_variance(preds, target)
|
||||
tensor(0.9572)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compute_on_step: bool = True,
|
||||
ddp_sync_on_step: bool = False,
|
||||
process_group: Optional[Any] = None,
|
||||
):
|
||||
super().__init__(
|
||||
compute_on_step=compute_on_step,
|
||||
ddp_sync_on_step=ddp_sync_on_step,
|
||||
process_group=process_group,
|
||||
)
|
||||
|
||||
self.add_state("y", default=[], dist_reduce_fx=None)
|
||||
self.add_state("y_pred", default=[], dist_reduce_fx=None)
|
||||
|
||||
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
||||
"""
|
||||
Update state with predictions and targets.
|
||||
|
||||
Args:
|
||||
preds: Predictions from model
|
||||
target: Ground truth values
|
||||
"""
|
||||
self.y.append(target)
|
||||
self.y_pred.append(preds)
|
||||
|
||||
def compute(self):
|
||||
"""
|
||||
Computes explained variance over state.
|
||||
"""
|
||||
y_true = torch.cat(self.y, dim=0)
|
||||
y_pred = torch.cat(self.y_pred, dim=0)
|
||||
|
||||
y_diff_avg = torch.mean(y_true - y_pred, dim=0)
|
||||
numerator = torch.mean((y_true - y_pred - y_diff_avg) ** 2, dim=0)
|
||||
|
||||
y_true_avg = torch.mean(y_true, dim=0)
|
||||
denominator = torch.mean((y_true - y_true_avg) ** 2, dim=0)
|
||||
|
||||
# TODO: multioutput
|
||||
return 1.0 - torch.mean(numerator / denominator)
|
|
@ -0,0 +1,51 @@
|
|||
import torch
|
||||
import pytest
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
|
||||
from pytorch_lightning.metrics.regression import ExplainedVariance
|
||||
from sklearn.metrics import explained_variance_score
|
||||
|
||||
from tests.metrics.utils import compute_batch, setup_ddp
|
||||
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
num_targets = 5
|
||||
|
||||
Input = namedtuple('Input', ["preds", "target"])
|
||||
|
||||
_single_target_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
|
||||
target=torch.rand(NUM_BATCHES, BATCH_SIZE),
|
||||
)
|
||||
|
||||
_multi_target_inputs = Input(
|
||||
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
|
||||
target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
|
||||
)
|
||||
|
||||
|
||||
def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score):
|
||||
sk_preds = preds.view(-1).numpy()
|
||||
sk_target = target.view(-1).numpy()
|
||||
return sk_fn(sk_target, sk_preds)
|
||||
|
||||
|
||||
def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):
|
||||
sk_preds = preds.view(-1, num_targets).numpy()
|
||||
sk_target = target.view(-1, num_targets).numpy()
|
||||
return sk_fn(sk_target, sk_preds)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ddp", [True, False])
|
||||
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"preds, target, sk_metric",
|
||||
[
|
||||
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
|
||||
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
|
||||
],
|
||||
)
|
||||
def test_explained_variance(ddp, ddp_sync_on_step, preds, target, sk_metric):
|
||||
compute_batch(preds, target, ExplainedVariance, sk_metric, ddp_sync_on_step, ddp)
|
Loading…
Reference in New Issue