lightning/pytorch_lightning/metrics/regression/explained_variance.py

64 lines
1.7 KiB
Python
Raw Normal View History

import torch
from typing import Any, Callable, Optional, Union
from pytorch_lightning.metrics.metric import Metric
class ExplainedVariance(Metric):
    """
    Computes explained variance.

    For each output the score is ``1 - Var(target - preds) / Var(target)``,
    using population (biased) variances computed along dim 0. Predictions and
    targets are accumulated across ``update`` calls and concatenated along
    dim 0 in ``compute``.

    Degenerate outputs follow the scikit-learn ``explained_variance_score``
    convention instead of producing ``nan``/``-inf``: an output whose target
    variance is zero scores ``1.0`` if its residual variance is also zero, and
    ``0.0`` otherwise.

    Args:
        compute_on_step: Forwarded to :class:`Metric`.
        ddp_sync_on_step: Forwarded to :class:`Metric`.
        process_group: Forwarded to :class:`Metric`.
        multioutput: How per-output scores are aggregated (relevant for 2-D inputs):

            * ``'uniform_average'`` (default): unweighted mean of the scores.
            * ``'raw_values'``: return the per-output scores without aggregation.
            * ``'variance_weighted'``: mean weighted by each output's target variance.

    Raises:
        ValueError: If ``multioutput`` is not one of the values listed above.

    Example:
        >>> from pytorch_lightning.metrics import ExplainedVariance
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> explained_variance = ExplainedVariance()
        >>> explained_variance(preds, target)
        tensor(0.9572)

        >>> target = torch.tensor([[0.5, 1.0], [-1.0, 1.0], [7.0, -6.0]])
        >>> preds = torch.tensor([[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]])
        >>> explained_variance = ExplainedVariance(multioutput='raw_values')
        >>> explained_variance(preds, target)
        tensor([0.9677, 1.0000])
    """

    _MULTIOUTPUT_OPTIONS = ("raw_values", "uniform_average", "variance_weighted")

    def __init__(
        self,
        compute_on_step: bool = True,
        ddp_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        multioutput: str = "uniform_average",
    ):
        # Validate before any base-class setup so a bad argument fails fast.
        if multioutput not in self._MULTIOUTPUT_OPTIONS:
            raise ValueError(
                f"Invalid value for `multioutput`: {multioutput!r}. "
                f"Expected one of {self._MULTIOUTPUT_OPTIONS}."
            )
        super().__init__(
            compute_on_step=compute_on_step,
            ddp_sync_on_step=ddp_sync_on_step,
            process_group=process_group,
        )
        self.multioutput = multioutput
        # Raw targets/predictions are kept as lists of tensors and only
        # concatenated in ``compute``.
        self.add_state("y", default=[], dist_reduce_fx=None)
        self.add_state("y_pred", default=[], dist_reduce_fx=None)

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values, with the same shape as ``preds``

        Raises:
            ValueError: If ``preds`` and ``target`` have different shapes.
                Without this check, e.g. ``(N,)`` and ``(N, 1)`` inputs would
                silently broadcast to ``(N, N)`` in ``compute``.
        """
        if preds.shape != target.shape:
            raise ValueError(
                f"Predictions and targets are expected to have the same shape, "
                f"got {tuple(preds.shape)} and {tuple(target.shape)}."
            )
        self.y.append(target)
        self.y_pred.append(preds)

    @staticmethod
    def _as_float(tensor: torch.Tensor) -> torch.Tensor:
        """Return ``tensor`` unchanged if it is floating point, else cast it to the default float dtype."""
        # torch.mean rejects integer tensors, so integer inputs must be cast first.
        if torch.is_floating_point(tensor):
            return tensor
        return tensor.to(torch.get_default_dtype())

    def compute(self):
        """
        Computes explained variance over state.

        Returns:
            A scalar tensor. For ``multioutput='raw_values'``, a tensor of
            per-output scores instead (0-d for 1-D inputs).

        Raises:
            RuntimeError: If no data has been accumulated via ``update``.
        """
        if len(self.y) == 0:
            raise RuntimeError("ExplainedVariance.compute() was called before any update().")

        y_true = self._as_float(torch.cat(self.y, dim=0))
        y_pred = self._as_float(torch.cat(self.y_pred, dim=0))

        residual = y_true - y_pred
        numerator = torch.mean((residual - torch.mean(residual, dim=0)) ** 2, dim=0)
        denominator = torch.mean((y_true - torch.mean(y_true, dim=0)) ** 2, dim=0)

        nonzero_numerator = numerator != 0
        nonzero_denominator = denominator != 0
        # Replace zero denominators by 1 so no nan/inf is ever created; the
        # affected entries are not selected by the torch.where below.
        safe_denominator = torch.where(
            nonzero_denominator, denominator, torch.ones_like(denominator)
        )
        # Zero target variance: perfect (1.0) only if the residual variance is
        # also zero, otherwise 0.0 (scikit-learn convention).
        scores = torch.where(
            nonzero_denominator,
            1.0 - numerator / safe_denominator,
            (~nonzero_numerator).to(numerator.dtype),
        )

        if self.multioutput == "raw_values":
            return scores
        if self.multioutput == "variance_weighted":
            if not nonzero_denominator.any():
                # No output has any target variance, so weights are undefined.
                return scores.new_tensor(0.0 if nonzero_numerator.any() else 1.0)
            return torch.sum(scores * denominator) / torch.sum(denominator)
        return torch.mean(scores)