From a74909affa0535da02e64b94f6d5f9b2da03c08f Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 23 Mar 2021 16:05:32 +0100
Subject: [PATCH] prune metrics: info retrieval (#6649)

---
 CHANGELOG.md                                  |   2 -
 pytorch_lightning/metrics/__init__.py         |   1 -
 .../metrics/functional/__init__.py            |   1 -
 .../functional/ir_average_precision.py        |  54 -------
 .../metrics/retrieval/__init__.py             |  15 --
 .../retrieval/mean_average_precision.py       |  61 --------
 .../metrics/retrieval/retrieval_metric.py     | 140 ------------------
 tests/metrics/functional/test_retrieval.py    |  36 -----
 tests/metrics/retrieval/__init__.py           |   0
 tests/metrics/retrieval/test_map.py           | 119 ---------------
 10 files changed, 429 deletions(-)
 delete mode 100644 pytorch_lightning/metrics/functional/ir_average_precision.py
 delete mode 100644 pytorch_lightning/metrics/retrieval/__init__.py
 delete mode 100644 pytorch_lightning/metrics/retrieval/mean_average_precision.py
 delete mode 100644 pytorch_lightning/metrics/retrieval/retrieval_metric.py
 delete mode 100644 tests/metrics/functional/test_retrieval.py
 delete mode 100644 tests/metrics/retrieval/__init__.py
 delete mode 100644 tests/metrics/retrieval/test_map.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 32cf9122ef..81bfa85cc0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,8 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added `RetrievalMAP` metric, the corresponding functional version `retrieval_average_precision` and a generic superclass for retrieval metrics `RetrievalMetric` ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))
-
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
 
diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py
index 500689f318..1da24737a3 100644
--- a/pytorch_lightning/metrics/__init__.py
+++ b/pytorch_lightning/metrics/__init__.py
@@ -39,7 +39,6 @@ from pytorch_lightning.metrics.regression import (  # noqa: F401
     R2Score,
     SSIM,
 )
-from pytorch_lightning.metrics.retrieval import RetrievalMAP  # noqa: F401
 
 warn(
     "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"
diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py
index 1701389cd1..3b31dad5d3 100644
--- a/pytorch_lightning/metrics/functional/__init__.py
+++ b/pytorch_lightning/metrics/functional/__init__.py
@@ -28,7 +28,6 @@ from pytorch_lightning.metrics.functional.f_beta import f1, fbeta  # noqa: F401
 from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance  # noqa: F401
 from pytorch_lightning.metrics.functional.image_gradients import image_gradients  # noqa: F401
 from pytorch_lightning.metrics.functional.iou import iou  # noqa: F401
-from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error  # noqa: F401
 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error  # noqa: F401
diff --git a/pytorch_lightning/metrics/functional/ir_average_precision.py b/pytorch_lightning/metrics/functional/ir_average_precision.py
deleted file mode 100644
index 83b14a21c5..0000000000
--- a/pytorch_lightning/metrics/functional/ir_average_precision.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import torch
-
-
-def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-    r"""
-    Computes average precision (for information retrieval), as explained
-    `here `_.
-
-    `preds` and `target` should be of the same shape and live on the same device. If no `target` is ``True``,
-    0 is returned. Target must be of type `bool` or `int`, otherwise an error is raised.
-
-    Args:
-        preds: estimated probabilities of each document to be relevant.
-        target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor.
-
-    Return:
-        a single-value tensor with the average precision (AP) of the predictions `preds` wrt the labels `target`.
-
-    Example:
-        >>> preds = torch.tensor([0.2, 0.3, 0.5])
-        >>> target = torch.tensor([True, False, True])
-        >>> retrieval_average_precision(preds, target)
-        tensor(0.8333)
-    """
-
-    if preds.shape != target.shape or preds.device != target.device:
-        raise ValueError("`preds` and `target` must have the same shape and live on the same device")
-
-    if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64):
-        raise ValueError("`target` must be a tensor of booleans or integers")
-
-    if target.dtype is not torch.bool:
-        target = target.bool()
-
-    if target.sum() == 0:
-        return torch.tensor(0, device=preds.device)
-
-    target = target[torch.argsort(preds, dim=-1, descending=True)]
-    positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0]
-    res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean()
-    return res
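For reference, the docstring example in the file removed above works out as follows: sorting preds = [0.2, 0.3, 0.5] in descending order reorders the targets to [True, False, True], so the relevant documents sit at ranks 1 and 3 and AP = (1/1 + 2/3) / 2 = 0.8333, which matches tensor(0.8333). Below is a minimal plain-Python sketch of the same computation; it is independent of the removed module and the helper name average_precision is illustrative only.

def average_precision(preds, target):
    # Rank documents by predicted score (descending), then average precision@k
    # over the positions of the relevant documents: the same idea as the
    # pruned retrieval_average_precision above.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    ranked = [target[i] for i in order]
    hits, total = 0, 0.0
    for rank, relevant in enumerate(ranked, start=1):
        if relevant:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0

print(average_precision([0.2, 0.3, 0.5], [True, False, True]))  # 0.8333...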
diff --git a/pytorch_lightning/metrics/retrieval/__init__.py b/pytorch_lightning/metrics/retrieval/__init__.py
deleted file mode 100644
index c5c12b3b66..0000000000
--- a/pytorch_lightning/metrics/retrieval/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP  # noqa: F401
-from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric  # noqa: F401
diff --git a/pytorch_lightning/metrics/retrieval/mean_average_precision.py b/pytorch_lightning/metrics/retrieval/mean_average_precision.py
deleted file mode 100644
index 956a53cca2..0000000000
--- a/pytorch_lightning/metrics/retrieval/mean_average_precision.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import torch
-
-from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision
-from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric
-
-
-class RetrievalMAP(RetrievalMetric):
-    r"""
-    Computes `Mean Average Precision
-    `_.
-
-    Works with binary data. Accepts integer or float predictions from a model output.
-
-    Forward accepts
-    - ``indexes`` (long tensor): ``(N, ...)``
-    - ``preds`` (float tensor): ``(N, ...)``
-    - ``target`` (long or bool tensor): ``(N, ...)``
-
-    `indexes`, `preds` and `target` must have the same dimension.
-    `indexes` indicate to which query a prediction belongs.
-    Predictions will be first grouped by indexes and then MAP will be computed as the mean
-    of the Average Precisions over each query.
-
-    Args:
-        query_without_relevant_docs:
-            Specify what to do with queries that do not have at least a positive target. Choose from:
-
-            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
-            - ``'error'``: raise a ``ValueError``
-            - ``'pos'``: score on those queries is counted as ``1.0``
-            - ``'neg'``: score on those queries is counted as ``0.0``
-        exclude:
-            Do not take into account predictions where the target is equal to this value. default `-100`
-        compute_on_step:
-            Forward only calls ``update()`` and return None if this is set to False. default: True
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step. default: False
-        process_group:
-            Specify the process group on which synchronization is called. default: None (which selects
-            the entire world)
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state. When `None`, DDP
-            will be used to perform the allgather. default: None
-
-    Example:
-        >>> from pytorch_lightning.metrics import RetrievalMAP
-        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
-        >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
-        >>> target = torch.tensor([False, False, True, False, True, False, False])
-
-        >>> map = RetrievalMAP()
-        >>> map(indexes, preds, target)
-        tensor(0.7500)
-        >>> map.compute()
-        tensor(0.7500)
-    """
-
-    def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        valid_indexes = target != self.exclude
-        return retrieval_average_precision(preds[valid_indexes], target[valid_indexes])
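In the RetrievalMAP docstring example just removed, grouping by indexes gives query 0 the predictions [0.2, 0.3, 0.5] with targets [False, False, True] (the relevant document ranks first, AP = 1.0) and query 1 the predictions [0.1, 0.3, 0.5, 0.2] with targets [False, True, False, False] (the relevant document ranks second, AP = 0.5), so MAP = (1.0 + 0.5) / 2 = 0.75, matching tensor(0.7500). A self-contained plain-Python sketch of that grouping step follows; the helper names are illustrative, and average_precision repeats the helper from the previous note so the block runs on its own.

from collections import defaultdict


def average_precision(preds, target):
    # Single-query average precision, same as the earlier sketch.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    ranked = [target[i] for i in order]
    hits, total = 0, 0.0
    for rank, relevant in enumerate(ranked, start=1):
        if relevant:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0


def mean_average_precision(indexes, preds, target):
    # Group predictions by query index, score each query, then take the mean.
    groups = defaultdict(lambda: ([], []))
    for idx, p, t in zip(indexes, preds, target):
        groups[idx][0].append(p)
        groups[idx][1].append(t)
    scores = [average_precision(p, t) for p, t in groups.values()]
    return sum(scores) / len(scores)


indexes = [0, 0, 0, 1, 1, 1, 1]
preds = [0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2]
target = [False, False, True, False, True, False, False]
print(mean_average_precision(indexes, preds, target))  # 0.75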
diff --git a/pytorch_lightning/metrics/retrieval/retrieval_metric.py b/pytorch_lightning/metrics/retrieval/retrieval_metric.py
deleted file mode 100644
index 6f9088d000..0000000000
--- a/pytorch_lightning/metrics/retrieval/retrieval_metric.py
+++ /dev/null
@@ -1,140 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, Callable, Optional
-
-import torch
-from torchmetrics import Metric
-
-from pytorch_lightning.metrics.utils import get_group_indexes
-
-#: get_group_indexes is used to group predictions belonging to the same query
-
-IGNORE_IDX = -100
-
-
-class RetrievalMetric(Metric, ABC):
-    r"""
-    Works with binary data. Accepts integer or float predictions from a model output.
-
-    Forward accepts
-    - ``indexes`` (long tensor): ``(N, ...)``
-    - ``preds`` (float or int tensor): ``(N, ...)``
-    - ``target`` (long or bool tensor): ``(N, ...)``
-
-    `indexes`, `preds` and `target` must have the same dimension and will be flatten
-    to single dimension once provided.
-
-    `indexes` indicate to which query a prediction belongs.
-    Predictions will be first grouped by indexes. Then the
-    real metric, defined by overriding the `_metric` method,
-    will be computed as the mean of the scores over each query.
-
-    Args:
-        query_without_relevant_docs:
-            Specify what to do with queries that do not have at least a positive target. Choose from:
-
-            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
-            - ``'error'``: raise a ``ValueError``
-            - ``'pos'``: score on those queries is counted as ``1.0``
-            - ``'neg'``: score on those queries is counted as ``0.0``
-        exclude:
-            Do not take into account predictions where the target is equal to this value. default `-100`
-        compute_on_step:
-            Forward only calls ``update()`` and return None if this is set to False. default: True
-        dist_sync_on_step:
-            Synchronize metric state across processes at each ``forward()``
-            before returning the value at the step. default: False
-        process_group:
-            Specify the process group on which synchronization is called. default: None (which selects
-            the entire world)
-        dist_sync_fn:
-            Callback that performs the allgather operation on the metric state. When `None`, DDP
-            will be used to perform the allgather. default: None
-
-    """
-
-    def __init__(
-        self,
-        query_without_relevant_docs: str = 'skip',
-        exclude: int = IGNORE_IDX,
-        compute_on_step: bool = True,
-        dist_sync_on_step: bool = False,
-        process_group: Optional[Any] = None,
-        dist_sync_fn: Callable = None
-    ):
-        super().__init__(
-            compute_on_step=compute_on_step,
-            dist_sync_on_step=dist_sync_on_step,
-            process_group=process_group,
-            dist_sync_fn=dist_sync_fn
-        )
-
-        query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg')
-        if query_without_relevant_docs not in query_without_relevant_docs_options:
-            raise ValueError(
-                f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. "
-                f"Allowed values are {query_without_relevant_docs_options}"
-            )
-
-        self.query_without_relevant_docs = query_without_relevant_docs
-        self.exclude = exclude
-
-        self.add_state("idx", default=[], dist_reduce_fx=None)
-        self.add_state("preds", default=[], dist_reduce_fx=None)
-        self.add_state("target", default=[], dist_reduce_fx=None)
-
-    def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None:
-        if not (idx.shape == target.shape == preds.shape):
-            raise ValueError("`idx`, `preds` and `target` must be of the same shape")
-
-        idx = idx.to(dtype=torch.int64).flatten()
-        preds = preds.to(dtype=torch.float32).flatten()
-        target = target.to(dtype=torch.int64).flatten()
-
-        self.idx.append(idx)
-        self.preds.append(preds)
-        self.target.append(target)
-
-    def compute(self) -> torch.Tensor:
-        r"""
-        First concat state `idx`, `preds` and `target` since they were stored as lists. After that,
-        compute list of groups that will help in keeping together predictions about the same query.
-        Finally, for each group compute the `_metric` if the number of positive targets is at least
-        1, otherwise behave as specified by `self.query_without_relevant_docs`.
- """ - - idx = torch.cat(self.idx, dim=0) - preds = torch.cat(self.preds, dim=0) - target = torch.cat(self.target, dim=0) - - res = [] - kwargs = {'device': idx.device, 'dtype': torch.float32} - - groups = get_group_indexes(idx) - for group in groups: - - mini_preds = preds[group] - mini_target = target[group] - - if not mini_target.sum(): - if self.query_without_relevant_docs == 'error': - raise ValueError( - f"`{self.__class__.__name__}.compute()` was provided with " - f"a query without positive targets, indexes: {group}" - ) - if self.query_without_relevant_docs == 'pos': - res.append(torch.tensor(1.0, **kwargs)) - elif self.query_without_relevant_docs == 'neg': - res.append(torch.tensor(0.0, **kwargs)) - else: - res.append(self._metric(mini_preds, mini_target)) - - if len(res) > 0: - return torch.stack(res).mean() - return torch.tensor(0.0, **kwargs) - - @abstractmethod - def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - r""" - Compute a metric over a predictions and target of a single group. - This method should be overridden by subclasses. - """ diff --git a/tests/metrics/functional/test_retrieval.py b/tests/metrics/functional/test_retrieval.py deleted file mode 100644 index a0573cba1d..0000000000 --- a/tests/metrics/functional/test_retrieval.py +++ /dev/null @@ -1,36 +0,0 @@ -import math - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(sk_average_precision, retrieval_average_precision), -]) -def test_against_sklearn(sklearn_metric, torch_metric): - """Compare PL metrics to sklearn version. """ - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - rounds = 25 - sizes = [1, 4, 10, 100] - - for size in sizes: - for _ in range(rounds): - a = np.random.randn(size) - b = np.random.randn(size) > 0 - - sk = torch.tensor(sklearn_metric(b, a), device=device) - pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device)) - - # `torch_metric`s return 0 when no label is True - # while `sklearn.average_precision_score` returns NaN - if math.isnan(sk): - assert pl == 0 - else: - assert torch.allclose(sk.float(), pl.float()) diff --git a/tests/metrics/retrieval/__init__.py b/tests/metrics/retrieval/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/metrics/retrieval/test_map.py b/tests/metrics/retrieval/test_map.py deleted file mode 100644 index fe43f19b20..0000000000 --- a/tests/metrics/retrieval/test_map.py +++ /dev/null @@ -1,119 +0,0 @@ -import math -import random -from typing import Callable, List - -import numpy as np -import pytest -import torch -from sklearn.metrics import average_precision_score as sk_average_precision -from torchmetrics import Metric - -from pytorch_lightning import seed_everything -from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP - - -@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [ - [sk_average_precision, RetrievalMAP], -]) -def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None: - """Compare PL metrics to sklearn version. 
""" - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - rounds = 20 - sizes = [1, 4, 10, 100] - batch_sizes = [1, 4, 10] - query_without_relevant_docs_options = ['skip', 'pos', 'neg'] - - def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> torch.Tensor: - """ Compute sk metric with multiple iterations using the base `sklearn_metric`. """ - sk_results = [] - kwargs = {'device': device, 'dtype': torch.float32} - - for b, a in zip(target, preds): - res = sklearn_metric(b, a) - - if math.isnan(res): - if behaviour == 'skip': - pass - elif behaviour == 'pos': - sk_results.append(torch.tensor(1.0, **kwargs)) - else: - sk_results.append(torch.tensor(0.0, **kwargs)) - else: - sk_results.append(torch.tensor(res, **kwargs)) - if len(sk_results) > 0: - sk_results = torch.stack(sk_results).mean() - else: - sk_results = torch.tensor(0.0, **kwargs) - - return sk_results - - def do_test(batch_size: int, size: int) -> None: - """ For each possible behaviour of the metric, check results are correct. """ - for behaviour in query_without_relevant_docs_options: - metric = torch_class_metric(query_without_relevant_docs=behaviour) - shape = (size, ) - - indexes = [] - preds = [] - target = [] - - for i in range(batch_size): - indexes.append(np.ones(shape, dtype=int) * i) - preds.append(np.random.randn(*shape)) - target.append(np.random.randn(*shape) > 0) - - sk_results = compute_sklearn_metric(target, preds, behaviour) - - indexes_tensor = torch.cat([torch.tensor(i) for i in indexes]) - preds_tensor = torch.cat([torch.tensor(p) for p in preds]) - target_tensor = torch.cat([torch.tensor(t) for t in target]) - - # lets assume data are not ordered - perm = torch.randperm(indexes_tensor.nelement()) - indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size()) - preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size()) - target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size()) - - # shuffle ids to require also sorting of documents ability from the lightning metric - pl_result = metric(indexes_tensor, preds_tensor, target_tensor) - - assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True) - - for batch_size in batch_sizes: - for size in sizes: - for _ in range(rounds): - do_test(batch_size, size) - - -@pytest.mark.parametrize(['torch_class_metric'], [ - [RetrievalMAP], -]) -def test_input_data(torch_class_metric: Metric) -> None: - """Check PL metrics inputs are controlled correctly. """ - - device = 'cuda' if torch.cuda.is_available() else 'cpu' - seed_everything(0) - - for _ in range(10): - - length = random.randint(0, 20) - - # check error when `query_without_relevant_docs='error'` is raised correctly - indexes = torch.tensor([0] * length, device=device, dtype=torch.int64) - preds = torch.rand(size=(length, ), device=device, dtype=torch.float32) - target = torch.tensor([False] * length, device=device, dtype=torch.bool) - - metric = torch_class_metric(query_without_relevant_docs='error') - - try: - metric(indexes, preds, target) - except Exception as e: - assert isinstance(e, ValueError) - - # check ValueError with non-accepted argument - try: - metric = torch_class_metric(query_without_relevant_docs='casual_argument') - except Exception as e: - assert isinstance(e, ValueError)