prune metrics: info retrieval (#6649)
parent 36d180e532
commit a74909affa
@@ -9,8 +9,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `RetrievalMAP` metric, the corresponding functional version `retrieval_average_precision` and a generic superclass for retrieval metrics `RetrievalMetric` ([#5032](https://github.com/PyTorchLightning/pytorch-lightning/pull/5032))


- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

@@ -39,7 +39,6 @@ from pytorch_lightning.metrics.regression import ( # noqa: F401
    R2Score,
    SSIM,
)
from pytorch_lightning.metrics.retrieval import RetrievalMAP # noqa: F401

warn(
    "`pytorch_lightning.metrics.*` module has been renamed to `torchmetrics.*` and split off to its own package"

@@ -28,7 +28,6 @@ from pytorch_lightning.metrics.functional.f_beta import f1, fbeta # noqa: F401
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
from pytorch_lightning.metrics.functional.image_gradients import image_gradients # noqa: F401
from pytorch_lightning.metrics.functional.iou import iou # noqa: F401
from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision # noqa: F401
from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401

@@ -1,54 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def retrieval_average_precision(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r"""
    Computes average precision (for information retrieval), as explained
    `here <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_.

    `preds` and `target` should be of the same shape and live on the same device. If no element of `target`
    is ``True``, ``0`` is returned. `target` must be of type `bool` or `int`, otherwise an error is raised.

    Args:
        preds: estimated probabilities of each document to be relevant.
        target: ground truth about each document being relevant or not. Requires `bool` or `int` tensor.

    Return:
        a single-value tensor with the average precision (AP) of the predictions `preds` w.r.t. the labels `target`.

    Example:
        >>> preds = torch.tensor([0.2, 0.3, 0.5])
        >>> target = torch.tensor([True, False, True])
        >>> retrieval_average_precision(preds, target)
        tensor(0.8333)
    """

    if preds.shape != target.shape or preds.device != target.device:
        raise ValueError("`preds` and `target` must have the same shape and live on the same device")

    if target.dtype not in (torch.bool, torch.int16, torch.int32, torch.int64):
        raise ValueError("`target` must be a tensor of booleans or integers")

    if target.dtype is not torch.bool:
        target = target.bool()

    if target.sum() == 0:
        return torch.tensor(0, device=preds.device)

    target = target[torch.argsort(preds, dim=-1, descending=True)]
    positions = torch.arange(1, len(target) + 1, device=target.device, dtype=torch.float32)[target > 0]
    res = torch.div((torch.arange(len(positions), device=positions.device, dtype=torch.float32) + 1), positions).mean()
    return res

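The ranking arithmetic above is compact; the following is a minimal plain-Python sketch of the same average-precision formula (an illustration only, not the library implementation), reproducing the docstring example by hand.

```python
# Hedged sketch: plain-Python restatement of the AP formula used above.
def average_precision(preds, target):
    # Rank documents by descending score, then average precision at each hit.
    order = sorted(range(len(preds)), key=lambda i: preds[i], reverse=True)
    hits, precisions = 0, []
    for rank, i in enumerate(order, start=1):
        if target[i]:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0

print(average_precision([0.2, 0.3, 0.5], [True, False, True]))  # 0.8333...
```
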
@@ -1,15 +0,0 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP # noqa: F401
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric # noqa: F401

@@ -1,61 +0,0 @@
import torch

from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision
from pytorch_lightning.metrics.retrieval.retrieval_metric import RetrievalMetric


class RetrievalMAP(RetrievalMetric):
    r"""
    Computes `Mean Average Precision
    <https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`_.

    Works with binary data. Accepts integer or float predictions from a model output.

    Forward accepts
    - ``indexes`` (long tensor): ``(N, ...)``
    - ``preds`` (float tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``

    `indexes`, `preds` and `target` must have the same dimension.
    `indexes` indicate to which query a prediction belongs.
    Predictions are first grouped by `indexes`, and MAP is then computed as the mean
    of the Average Precisions over each query.

    Args:
        query_without_relevant_docs:
            Specify what to do with queries that do not have at least one positive target. Choose from:

            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``
            - ``'pos'``: score on those queries is counted as ``1.0``
            - ``'neg'``: score on those queries is counted as ``0.0``
        exclude:
            Do not take into account predictions where the target is equal to this value. default: ``-100``
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``

    Example:
        >>> from pytorch_lightning.metrics import RetrievalMAP
        >>> indexes = torch.tensor([0, 0, 0, 1, 1, 1, 1])
        >>> preds = torch.tensor([0.2, 0.3, 0.5, 0.1, 0.3, 0.5, 0.2])
        >>> target = torch.tensor([False, False, True, False, True, False, False])

        >>> map = RetrievalMAP()
        >>> map(indexes, preds, target)
        tensor(0.7500)
        >>> map.compute()
        tensor(0.7500)
    """

    def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        valid_indexes = target != self.exclude
        return retrieval_average_precision(preds[valid_indexes], target[valid_indexes])

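As a quick sanity check on the docstring example above (an illustrative hand computation, not part of the module): query ``0`` ranks its only relevant document first, query ``1`` ranks its relevant document second, and MAP is the mean of the two per-query AP values.

```python
# Hedged hand-check of the docstring example above (plain Python).
# query 0: scores (0.5, 0.3, 0.2) -> targets (T, F, F)    -> AP = 1/1 = 1.0
# query 1: scores (0.5, 0.3, 0.2, 0.1) -> targets (F, T, F, F) -> AP = 1/2 = 0.5
print((1.0 + 0.5) / 2)  # 0.75, matching tensor(0.7500)
```
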
@@ -1,140 +0,0 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

import torch
from torchmetrics import Metric

from pytorch_lightning.metrics.utils import get_group_indexes

#: get_group_indexes is used to group predictions belonging to the same query

IGNORE_IDX = -100


class RetrievalMetric(Metric, ABC):
    r"""
    Works with binary data. Accepts integer or float predictions from a model output.

    Forward accepts
    - ``indexes`` (long tensor): ``(N, ...)``
    - ``preds`` (float or int tensor): ``(N, ...)``
    - ``target`` (long or bool tensor): ``(N, ...)``

    `indexes`, `preds` and `target` must have the same dimension and will be flattened
    to a single dimension once provided.

    `indexes` indicate to which query a prediction belongs.
    Predictions are first grouped by `indexes`. Then the
    actual metric, defined by overriding the `_metric` method,
    is computed as the mean of the scores over each query.

    Args:
        query_without_relevant_docs:
            Specify what to do with queries that do not have at least one positive target. Choose from:

            - ``'skip'``: skip those queries (default); if all queries are skipped, ``0.0`` is returned
            - ``'error'``: raise a ``ValueError``
            - ``'pos'``: score on those queries is counted as ``1.0``
            - ``'neg'``: score on those queries is counted as ``0.0``
        exclude:
            Do not take into account predictions where the target is equal to this value. default: ``-100``
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: ``True``
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: ``False``
        process_group:
            Specify the process group on which synchronization is called. default: ``None`` (which selects
            the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather. default: ``None``

    """

    def __init__(
        self,
        query_without_relevant_docs: str = 'skip',
        exclude: int = IGNORE_IDX,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Callable = None
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn
        )

        query_without_relevant_docs_options = ('error', 'skip', 'pos', 'neg')
        if query_without_relevant_docs not in query_without_relevant_docs_options:
            raise ValueError(
                f"`query_without_relevant_docs` received a wrong value {query_without_relevant_docs}. "
                f"Allowed values are {query_without_relevant_docs_options}"
            )

        self.query_without_relevant_docs = query_without_relevant_docs
        self.exclude = exclude

        self.add_state("idx", default=[], dist_reduce_fx=None)
        self.add_state("preds", default=[], dist_reduce_fx=None)
        self.add_state("target", default=[], dist_reduce_fx=None)

    def update(self, idx: torch.Tensor, preds: torch.Tensor, target: torch.Tensor) -> None:
        if not (idx.shape == target.shape == preds.shape):
            raise ValueError("`idx`, `preds` and `target` must be of the same shape")

        idx = idx.to(dtype=torch.int64).flatten()
        preds = preds.to(dtype=torch.float32).flatten()
        target = target.to(dtype=torch.int64).flatten()

        self.idx.append(idx)
        self.preds.append(preds)
        self.target.append(target)

    def compute(self) -> torch.Tensor:
        r"""
        First concatenate the states `idx`, `preds` and `target`, since they were stored as lists. After that,
        compute the list of groups that keeps together predictions about the same query.
        Finally, for each group compute the `_metric` if the number of positive targets is at least
        1; otherwise behave as specified by `self.query_without_relevant_docs`.
        """

        idx = torch.cat(self.idx, dim=0)
        preds = torch.cat(self.preds, dim=0)
        target = torch.cat(self.target, dim=0)

        res = []
        kwargs = {'device': idx.device, 'dtype': torch.float32}

        groups = get_group_indexes(idx)
        for group in groups:

            mini_preds = preds[group]
            mini_target = target[group]

            if not mini_target.sum():
                if self.query_without_relevant_docs == 'error':
                    raise ValueError(
                        f"`{self.__class__.__name__}.compute()` was provided with "
                        f"a query without positive targets, indexes: {group}"
                    )
                if self.query_without_relevant_docs == 'pos':
                    res.append(torch.tensor(1.0, **kwargs))
                elif self.query_without_relevant_docs == 'neg':
                    res.append(torch.tensor(0.0, **kwargs))
            else:
                res.append(self._metric(mini_preds, mini_target))

        if len(res) > 0:
            return torch.stack(res).mean()
        return torch.tensor(0.0, **kwargs)

    @abstractmethod
    def _metric(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        r"""
        Compute a metric over the predictions and target of a single group.
        This method should be overridden by subclasses.
        """

@@ -1,36 +0,0 @@
import math

import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as sk_average_precision

from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.functional.ir_average_precision import retrieval_average_precision


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
    pytest.param(sk_average_precision, retrieval_average_precision),
])
def test_against_sklearn(sklearn_metric, torch_metric):
    """Compare PL metrics to the sklearn version."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed_everything(0)

    rounds = 25
    sizes = [1, 4, 10, 100]

    for size in sizes:
        for _ in range(rounds):
            a = np.random.randn(size)
            b = np.random.randn(size) > 0

            sk = torch.tensor(sklearn_metric(b, a), device=device)
            pl = torch_metric(torch.tensor(a, device=device), torch.tensor(b, device=device))

            # `torch_metric`s return 0 when no label is True
            # while `sklearn.average_precision_score` returns NaN
            if math.isnan(sk):
                assert pl == 0
            else:
                assert torch.allclose(sk.float(), pl.float())

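The comment above is the subtle part of this test. Below is a minimal standalone illustration of that edge case; note it is hedged, since the exact sklearn behaviour for all-negative labels can vary across versions, and the NaN result is what this test relied on.

```python
# Hedged illustration of the edge case guarded above: with no positive label,
# average precision is undefined; the sklearn versions this test targeted
# returned NaN there, while the pruned torch implementation returns 0.
import math
import numpy as np
from sklearn.metrics import average_precision_score

preds = np.array([0.1, 0.4, 0.3])
labels = np.zeros(3, dtype=bool)  # no relevant document at all
print(math.isnan(average_precision_score(labels, preds)))  # True on those versions
```
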
@@ -1,119 +0,0 @@
import math
import random
from typing import Callable, List

import numpy as np
import pytest
import torch
from sklearn.metrics import average_precision_score as sk_average_precision
from torchmetrics import Metric

from pytorch_lightning import seed_everything
from pytorch_lightning.metrics.retrieval.mean_average_precision import RetrievalMAP


@pytest.mark.parametrize(['sklearn_metric', 'torch_class_metric'], [
    [sk_average_precision, RetrievalMAP],
])
def test_against_sklearn(sklearn_metric: Callable, torch_class_metric: Metric) -> None:
    """Compare PL metrics to the sklearn version."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed_everything(0)

    rounds = 20
    sizes = [1, 4, 10, 100]
    batch_sizes = [1, 4, 10]
    query_without_relevant_docs_options = ['skip', 'pos', 'neg']

    def compute_sklearn_metric(target: List[np.ndarray], preds: List[np.ndarray], behaviour: str) -> torch.Tensor:
        """Compute the sklearn metric over multiple iterations using the base `sklearn_metric`."""
        sk_results = []
        kwargs = {'device': device, 'dtype': torch.float32}

        for b, a in zip(target, preds):
            res = sklearn_metric(b, a)

            if math.isnan(res):
                if behaviour == 'skip':
                    pass
                elif behaviour == 'pos':
                    sk_results.append(torch.tensor(1.0, **kwargs))
                else:
                    sk_results.append(torch.tensor(0.0, **kwargs))
            else:
                sk_results.append(torch.tensor(res, **kwargs))
        if len(sk_results) > 0:
            sk_results = torch.stack(sk_results).mean()
        else:
            sk_results = torch.tensor(0.0, **kwargs)

        return sk_results

    def do_test(batch_size: int, size: int) -> None:
        """For each possible behaviour of the metric, check that results are correct."""
        for behaviour in query_without_relevant_docs_options:
            metric = torch_class_metric(query_without_relevant_docs=behaviour)
            shape = (size, )

            indexes = []
            preds = []
            target = []

            for i in range(batch_size):
                indexes.append(np.ones(shape, dtype=int) * i)
                preds.append(np.random.randn(*shape))
                target.append(np.random.randn(*shape) > 0)

            sk_results = compute_sklearn_metric(target, preds, behaviour)

            indexes_tensor = torch.cat([torch.tensor(i) for i in indexes])
            preds_tensor = torch.cat([torch.tensor(p) for p in preds])
            target_tensor = torch.cat([torch.tensor(t) for t in target])

            # let's assume data are not ordered
            perm = torch.randperm(indexes_tensor.nelement())
            indexes_tensor = indexes_tensor.view(-1)[perm].view(indexes_tensor.size())
            preds_tensor = preds_tensor.view(-1)[perm].view(preds_tensor.size())
            target_tensor = target_tensor.view(-1)[perm].view(target_tensor.size())

            # shuffle ids so the lightning metric also has to sort documents within each query
            pl_result = metric(indexes_tensor, preds_tensor, target_tensor)

            assert torch.allclose(sk_results.float(), pl_result.float(), equal_nan=True)

    for batch_size in batch_sizes:
        for size in sizes:
            for _ in range(rounds):
                do_test(batch_size, size)


@pytest.mark.parametrize(['torch_class_metric'], [
    [RetrievalMAP],
])
def test_input_data(torch_class_metric: Metric) -> None:
    """Check that PL metric inputs are validated correctly."""

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed_everything(0)

    for _ in range(10):

        length = random.randint(0, 20)

        # check that the error for `query_without_relevant_docs='error'` is raised correctly
        indexes = torch.tensor([0] * length, device=device, dtype=torch.int64)
        preds = torch.rand(size=(length, ), device=device, dtype=torch.float32)
        target = torch.tensor([False] * length, device=device, dtype=torch.bool)

        metric = torch_class_metric(query_without_relevant_docs='error')

        try:
            metric(indexes, preds, target)
        except Exception as e:
            assert isinstance(e, ValueError)

        # check ValueError with a non-accepted argument
        try:
            metric = torch_class_metric(query_without_relevant_docs='casual_argument')
        except Exception as e:
            assert isinstance(e, ValueError)

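The permutation step in ``do_test`` above is easy to get wrong; the snippet below is a minimal illustration (not part of the test) of why applying the same ``perm`` to all three flat tensors keeps each ``(index, pred, target)`` triple intact while destroying any per-query ordering.

```python
# Hedged sketch: one shared permutation keeps the three tensors aligned.
import torch

indexes = torch.tensor([0, 0, 1, 1])
preds = torch.tensor([0.9, 0.1, 0.8, 0.2])
target = torch.tensor([True, False, False, True])

perm = torch.randperm(indexes.nelement())
indexes, preds, target = indexes[perm], preds[perm], target[perm]
print(indexes, preds, target)  # same (index, pred, target) triples, new order
```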