prune metrics: info retrieval (#6649)

Jirka Borovec 2021-03-23 16:05:32 +01:00 committed by GitHub
parent 36d180e532
commit a74909affa
10 changed files with 0 additions and 429 deletions

View File

@@ -9,8 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `RetrievalMAP` metric, the corresponding functional version `retrieval_average_precision` and a generic superclass for retrieval metrics `RetrievalMetric` ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

View File

@@ -39,7 +39,6 @@ from pytorch_lightning.metrics.regression import ( # noqa: F401
R2Score,
SSIM,
)
from pytorch_lightning.metrics.retrieval import RetrievalMAP # noqa: F401
warn(
"`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"

View File

@@ -28,7 +28,6 @@ from pytorch_lightning.metrics.functional.f_beta import f1, fbeta # noqa: F401
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401
from pytorch_lightning.metrics.functional.iou import iou # noqa: F401
from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision # noqa: F401
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401

View File

@@ -1,54 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""
Computes average precision (for information retrieval), as explained
`here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_.
`preds` and `target` should be of the same shape and live on the same device. If no element of `target` is ``True``,
0 is returned. `target` must be of type `bool` or `int`, otherwise an error is raised.
Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor.
Return:
a single-value tensor with the average precision (AP) of the predictions `preds` with respect to the labels `target`.
Example:
>>> preds = torch.tensor([0.2, 0.3, 0.5])
>>> target = torch.tensor([True, False, True])
>>> retrieval_average_precision(preds, target)
tensor(0.8333)
"""
if preds.shape != target.shape or preds.device != target.device:
raise ValueError("`preds` and `target` must have the same shape and live on the same device")
if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64):
raise ValueError("`target` must be a tensor of booleans or integers")
if target.dtype is not torch.bool:
target = target.bool()
if target.sum() == 0:
return torch.tensor(0, device=preds.device)
target = target[torch.argsort(preds, dim=-1, descending=True)]
positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0]
res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean()
return res
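
For reference, the vectorized code above matches the textbook definition of AP: rank the documents by decreasing score, take the precision at every rank that holds a relevant document, and average those precisions. A minimal plain-Python sketch of that equivalence for the docstring example (illustrative only, not part of this commit):

# Average precision computed "by hand" for the docstring example above.
preds = [0.2, 0.3, 0.5]
target = [True, False, True]

# rank targets by decreasing prediction score -> [True, False, True]
order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
ranked = [target[i] for i in order]

# precision at every rank holding a relevant document: 1/1 and 2/3
precisions = []
hits = 0
for rank, relevant in enumerate(ranked, start=1):
    if relevant:
        hits += 1
        precisions.append(hits / rank)

ap = sum(precisions) / len(precisions) if precisions else 0.0
print(round(ap, 4))  # 0.8333, matching retrieval_average_precision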

View File

@@ -1,15 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401

View File

@@ -1,61 +0,0 @@
import torch
from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric
class RetrievalMAP(RetrievalMetric):
r"""
Computes `Mean Average Precision
<https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`_.
Works with binary data. Accepts integer or float predictions from a model output.
Forward accepts
- ``indexes`` (long tensor): ``(N, ...)``
- ``preds`` (float tensor): ``(N, ...)``
- ``target`` (long or bool tensor): ``(N, ...)``
`indexes`, `preds` and `target` must have the same dimension.
`indexes` indicate to which query a prediction belongs.
Predictions are first grouped by `indexes` and MAP is then computed as the mean
of the Average Precision over each query.
Args:
query_without_relevant_docs:
Specify what to do with queries that do not have at least one positive target. Choose from:
- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
- ``'pos'``: score on those queries is counted as ``1.0``
- ``'neg'``: score on those queries is counted as ``0.0``
exclude:
Do not take into account predictions where the target is equal to this value. default `-100`
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects
the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
Example:
>>> from pytorch_lightning.metrics import RetrievalMAP
>>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
>>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
>>> target = torch.tensor([False, False, True, False, True, False, False])
>>> map = RetrievalMAP()
>>> map(indexes, preds, target)
tensor(0.7500)
>>> map.compute()
tensor(0.7500)
"""
def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
valid_indexes = target != self.exclude
return retrieval_average_precision(preds[valid_indexes], target[valid_indexes])
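
To make the grouping explicit: MAP is simply the mean of the per-query average precisions. A short sketch reproducing the docstring value, assuming the `retrieval_average_precision` function from the functional module removed earlier in this diff (illustrative only, not part of this commit):

import torch

indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
target = torch.tensor([False, False, True, False, True, False, False])

# compute AP separately for each query, then average
per_query = []
for query in indexes.unique():
    mask = indexes == query
    per_query.append(retrieval_average_precision(preds[mask], target[mask]))

# query 0 -> AP = 1.0 (its single relevant document is ranked first)
# query 1 -> AP = 0.5 (its single relevant document is ranked second)
print(torch.stack(per_query).mean())  # tensor(0.7500), as in the docstring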

View File

@@ -1,140 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional
import torch
from torchmetrics import Metric
from pytorch_lightning.metrics.utils import get_group_indexes
#: get_group_indexes is used to group predictions belonging to the same query
IGNORE_IDX = -100
class RetrievalMetric(Metric, ABC):
r"""
Works with binary data. Accepts integer or float predictions from a model output.
Forward accepts
- ``indexes`` (long tensor): ``(N, ...)``
- ``preds`` (float or int tensor): ``(N, ...)``
- ``target`` (long or bool tensor): ``(N, ...)``
`indexes`, `preds` and `target` must have the same dimension and will be flattened
to a single dimension once provided.
`indexes` indicate to which query a prediction belongs.
Predictions are first grouped by `indexes`. Then the
actual metric, defined by overriding the `_metric` method,
is computed as the mean of the scores over each query.
Args:
query_without_relevant_docs:
Specify what to do with queries that do not have at least one positive target. Choose from:
- ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
- ``'error'``: raise a ``ValueError``
- ``'pos'``: score on those queries is counted as ``1.0``
- ``'neg'``: score on those queries is counted as ``0.0``
exclude:
Do not take into account predictions where the target is equal to this value. default `-100`
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects
the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
"""
def __init__(
self,
query_without_relevant_docs: str = 'skip',
exclude: int = IGNORE_IDX,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn
)
query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg')
if query_without_relevant_docs not in query_without_relevant_docs_options:
raise ValueError(
f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. "
f"Allowed values are {query_without_relevant_docs_options}"
)
self.query_without_relevant_docs = query_without_relevant_docs
self.exclude = exclude
self.add_state("idx", default=[], dist_reduce_fx=None)
self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)
def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None:
if not (idx.shape == target.shape == preds.shape):
raise ValueError("`idx`, `preds` and `target` must be of the same shape")
idx = idx.to(dtype=torch.int64).flatten()
preds = preds.to(dtype=torch.float32).flatten()
target = target.to(dtype=torch.int64).flatten()
self.idx.append(idx)
self.preds.append(preds)
self.target.append(target)
def compute(self) -> torch.Tensor:
r"""
First concatenate the states `idx`, `preds` and `target`, since they are stored as lists. After that,
compute the list of groups that keeps together predictions belonging to the same query.
Finally, for each group compute the `_metric` if the number of positive targets is at least
1; otherwise behave as specified by `self.query_without_relevant_docs`.
"""
idx = torch.cat(self.idx, dim=0)
preds = torch.cat(self.preds, dim=0)
target = torch.cat(self.target, dim=0)
res = []
kwargs = {'device': idx.device, 'dtype': torch.float32}
groups = get_group_indexes(idx)
for group in groups:
mini_preds = preds[group]
mini_target = target[group]
if not mini_target.sum():
if self.query_without_relevant_docs == 'error':
raise ValueError(
f"`{self.__class__.__name__}.compute()` was provided with "
f"a query without positive targets, indexes: {group}"
)
if self.query_without_relevant_docs == 'pos':
res.append(torch.tensor(1.0, **kwargs))
elif self.query_without_relevant_docs == 'neg':
res.append(torch.tensor(0.0, **kwargs))
else:
res.append(self._metric(mini_preds, mini_target))
if len(res) > 0:
return torch.stack(res).mean()
return torch.tensor(0.0, **kwargs)
@abstractmethod
def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
r"""
Compute the metric over the predictions and targets of a single group.
This method should be overridden by subclasses.
"""

View File

@@ -1,36 +0,0 @@
import math
import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as sk_average_precision
from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision
@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(sk_average_precision, retrieval_average_precision),
])
def test_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed_everything(0)
rounds = 25
sizes = [1, 4, 10, 100]
for size in sizes:
for _ in range(rounds):
a = np.random.randn(size)
b = np.random.randn(size) > 0
sk = torch.tensor(sklearn_metric(b, a), device=device)
pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device))
# the torch metric returns 0 when no label is True,
# while `sklearn.average_precision_score` returns NaN
if math.isnan(sk):
assert pl == 0
else:
assert torch.allclose(sk.float(), pl.float())

View File

@@ -1,119 +0,0 @@
import math
import random
from typing import Callable, List
import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as sk_average_precision
from torchmetrics import Metric
from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP
@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [
[sk_average_precision, RetrievalMAP],
])
def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None:
"""Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed_everything(0)
rounds = 20
sizes = [1, 4, 10, 100]
batch_sizes = [1, 4, 10]
query_without_relevant_docs_options = ['skip', 'pos', 'neg']
def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> torch.Tensor:
""" Compute sk metric with multiple iterations using the base `sklearn_metric`. """
sk_results = []
kwargs = {'device': device, 'dtype': torch.float32}
for b, a in zip(target, preds):
res = sklearn_metric(b, a)
if math.isnan(res):
if behaviour == 'skip':
pass
elif behaviour == 'pos':
sk_results.append(torch.tensor(1.0, **kwargs))
else:
sk_results.append(torch.tensor(0.0, **kwargs))
else:
sk_results.append(torch.tensor(res, **kwargs))
if len(sk_results) > 0:
sk_results = torch.stack(sk_results).mean()
else:
sk_results = torch.tensor(0.0, **kwargs)
return sk_results
def do_test(batch_size: int, size: int) -> None:
""" For each possible behaviour of the metric, check results are correct. """
for behaviour in query_without_relevant_docs_options:
metric = torch_class_metric(query_without_relevant_docs=behaviour)
shape = (size, )
indexes = []
preds = []
target = []
for i in range(batch_size):
indexes.append(np.ones(shape, dtype=int) * i)
preds.append(np.random.randn(*shape))
target.append(np.random.randn(*shape) > 0)
sk_results = compute_sklearn_metric(target, preds, behaviour)
indexes_tensor = torch.cat([torch.tensor(i) for i in indexes])
preds_tensor = torch.cat([torch.tensor(p) for p in preds])
target_tensor = torch.cat([torch.tensor(t) for t in target])
# let's assume the data are not ordered
perm = torch.randperm(indexes_tensor.nelement())
indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size())
preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size())
target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size())
# shuffle ids so that the lightning metric also has to sort the documents
pl_result = metric(indexes_tensor, preds_tensor, target_tensor)
assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True)
for batch_size in batch_sizes:
for size in sizes:
for _ in range(rounds):
do_test(batch_size, size)
@pytest.mark.parametrize(['torch_class_metric'], [
[RetrievalMAP],
])
def test_input_data(torch_class_metric: Metric) -> None:
"""Check PL metrics inputs are controlled correctly. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed_everything(0)
for _ in range(10):
length = random.randint(0, 20)
# check that the error for `query_without_relevant_docs='error'` is raised correctly
indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
preds = torch.rand(size=(length, ), device=device, dtype=torch.float32)
target = torch.tensor([False] * length, device=device, dtype=torch.bool)
metric = torch_class_metric(query_without_relevant_docs='error')
try:
metric(indexes, preds, target)
except Exception as e:
assert isinstance(e, ValueError)
# check ValueError with non-accepted argument
try:
metric = torch_class_metric(query_without_relevant_docs='casual_argument')
except Exception as e:
assert isinstance(e, ValueError)