diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 817d0b7074..bca93416e7 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -181,6 +181,13 @@ MeanSquaredLogError :noindex: +ExplainedVariance +^^^^^^^^^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance + :noindex: + + Functional Metrics ================== diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 615a08e27a..6a20c6a0b1 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,4 +1,9 @@ from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.classification.accuracy import Accuracy -from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError +from pytorch_lightning.metrics.regression import ( + MeanSquaredError, + MeanAbsoluteError, + MeanSquaredLogError, + ExplainedVariance, +) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index acd2b2d5e2..34c8ef88a9 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -96,7 +96,11 @@ class Metric(nn.Module, ABC): the format discussed in the above note. """ - if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0): + if ( + not isinstance(default, torch.Tensor) + and not isinstance(default, list) # noqa: W503 + or (isinstance(default, list) and len(default) != 0) # noqa: W503 + ): raise ValueError( "state variable must be a tensor or any empty list (where you can append tensors)" ) @@ -163,7 +167,7 @@ class Metric(nn.Module, ABC): elif isinstance(output_dict[attr][0], list): output_dict[attr] = _flatten(output_dict[attr]) - assert isinstance(reduction_fn, (Callable, None)) + assert isinstance(reduction_fn, (Callable)) or reduction_fn is None reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] setattr(self, attr, reduced) diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py index c5f235aeff..a7893e9c26 100644 --- a/pytorch_lightning/metrics/regression/__init__.py +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -1,3 +1,4 @@ from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError +from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py new file mode 100644 index 0000000000..e12ab8d8a5 --- /dev/null +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -0,0 +1,63 @@ +import torch +from typing import Any, Callable, Optional, Union + +from pytorch_lightning.metrics.metric import Metric + + +class ExplainedVariance(Metric): + """ + Computes explained variance. + + Example: + + >>> from pytorch_lightning.metrics import ExplainedVariance + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> explained_variance = ExplainedVariance() + >>> explained_variance(preds, target) + tensor(0.9572) + + + """ + + def __init__( + self, + compute_on_step: bool = True, + ddp_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + ddp_sync_on_step=ddp_sync_on_step, + process_group=process_group, + ) + + self.add_state("y", default=[], dist_reduce_fx=None) + self.add_state("y_pred", default=[], dist_reduce_fx=None) + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + self.y.append(target) + self.y_pred.append(preds) + + def compute(self): + """ + Computes explained variance over state. + """ + y_true = torch.cat(self.y, dim=0) + y_pred = torch.cat(self.y_pred, dim=0) + + y_diff_avg = torch.mean(y_true - y_pred, dim=0) + numerator = torch.mean((y_true - y_pred - y_diff_avg) ** 2, dim=0) + + y_true_avg = torch.mean(y_true, dim=0) + denominator = torch.mean((y_true - y_true_avg) ** 2, dim=0) + + # TODO: multioutput + return 1.0 - torch.mean(numerator / denominator) diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py new file mode 100644 index 0000000000..7c54d486ef --- /dev/null +++ b/tests/metrics/regression/test_explained_variance.py @@ -0,0 +1,51 @@ +import torch +import pytest +from collections import namedtuple +from functools import partial + +from pytorch_lightning.metrics.regression import ExplainedVariance +from sklearn.metrics import explained_variance_score + +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE + +torch.manual_seed(42) + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_fn(sk_target, sk_preds) + + +def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return sk_fn(sk_target, sk_preds) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize( + "preds, target, sk_metric", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), + ], +) +def test_explained_variance(ddp, ddp_sync_on_step, preds, target, sk_metric): + compute_batch(preds, target, ExplainedVariance, sk_metric, ddp_sync_on_step, ddp)