64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
|
import torch
|
||
|
from typing import Any, Callable, Optional, Union
|
||
|
|
||
|
from pytorch_lightning.metrics.metric import Metric
|
||
|
|
||
|
|
||
|
class ExplainedVariance(Metric):
    """
    Computes explained variance.

    For each output column the score is ``1 - Var(target - preds) / Var(target)``,
    where both variances are population (biased) variances taken along dim 0 of
    all accumulated samples. Scores of multiple output columns are averaged
    uniformly.

    If ``target`` is constant for an output (zero variance), that output scores
    ``1.0`` when the residual variance is also zero and ``0.0`` otherwise, rather
    than producing ``nan``/``-inf`` from a division by zero.

    Args:
        compute_on_step: Forwarded unchanged to ``Metric.__init__``.
        ddp_sync_on_step: Forwarded unchanged to ``Metric.__init__``.
        process_group: Forwarded unchanged to ``Metric.__init__``.

    Example:

        >>> from pytorch_lightning.metrics import ExplainedVariance
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> explained_variance = ExplainedVariance()
        >>> explained_variance(preds, target)
        tensor(0.9572)

    """

    def __init__(
        self,
        compute_on_step: bool = True,
        ddp_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            ddp_sync_on_step=ddp_sync_on_step,
            process_group=process_group,
        )

        # Every batch passed to update() is appended to these lists and
        # concatenated along dim 0 in compute().
        self.add_state("y", default=[], dist_reduce_fx=None)
        self.add_state("y_pred", default=[], dist_reduce_fx=None)

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values; must have the same shape as ``preds``

        Raises:
            ValueError: if ``preds`` and ``target`` have different shapes.
        """
        # Mismatched shapes would otherwise broadcast silently in compute()
        # (e.g. (N, 1) vs (N,) -> (N, N)) and give a meaningless score.
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have the same shape, got "
                f"{tuple(preds.shape)} and {tuple(target.shape)}"
            )
        self.y.append(target)
        self.y_pred.append(preds)

    def compute(self):
        """
        Computes explained variance over state.

        Returns:
            A scalar tensor: the explained variance, averaged uniformly over
            output columns when the accumulated inputs have more than one
            dimension.

        Raises:
            RuntimeError: if called before any ``update``.
        """
        if len(self.y) == 0:
            raise RuntimeError("ExplainedVariance.compute() called before any update()")

        y_true = torch.cat(self.y, dim=0)
        y_pred = torch.cat(self.y_pred, dim=0)

        # torch.mean rejects integer tensors: promote non-floating inputs to
        # float32, but leave floating inputs (e.g. float64) at their precision.
        if not y_true.is_floating_point():
            y_true = y_true.float()
        if not y_pred.is_floating_point():
            y_pred = y_pred.float()

        residual = y_true - y_pred
        numerator = torch.mean((residual - torch.mean(residual, dim=0)) ** 2, dim=0)
        denominator = torch.mean((y_true - torch.mean(y_true, dim=0)) ** 2, dim=0)

        # Guard the division: a constant target has zero variance. Such outputs
        # score 1.0 if predictions explain it perfectly (zero residual
        # variance) and 0.0 otherwise, mirroring sklearn's convention.
        nonzero_den = denominator != 0
        safe_den = torch.where(nonzero_den, denominator, torch.ones_like(denominator))
        constant_target_score = (numerator == 0).to(numerator.dtype)
        scores = torch.where(nonzero_den, 1.0 - numerator / safe_den, constant_target_score)

        # Uniform average over outputs (sklearn's multioutput="uniform_average").
        # TODO: support other multioutput modes (e.g. "raw_values", "variance_weighted")
        return torch.mean(scores)
|