fix handling of non symmetric scorers in pure python mode
This commit is contained in:
parent
383a91f8c4
commit
5946aeb1bc
|
@ -1,6 +1,12 @@
|
|||
Changelog
|
||||
---------
|
||||
|
||||
[2.14.0] -
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
Fixed
|
||||
~~~~~~~
|
||||
- fix handling of non symmetric scorers in pure python version of ``process.cdist``
|
||||
|
||||
[2.13.7] - 2022-12-20
|
||||
^^^^^^^^^^^^^^^^^^^^^
|
||||
Fixed
|
||||
|
|
|
@ -18,7 +18,7 @@ def _get_scorer_flags_distance(**_kwargs: Any) -> dict[str, Any]:
|
|||
return {
|
||||
"optimal_score": 0,
|
||||
"worst_score": 2**63 - 1,
|
||||
"flags": ScorerFlag.RESULT_I64,
|
||||
"flags": ScorerFlag.RESULT_I64 | ScorerFlag.SYMMETRIC,
|
||||
}
|
||||
|
||||
|
||||
|
@ -26,7 +26,23 @@ def _get_scorer_flags_similarity(**_kwargs: Any) -> dict[str, Any]:
|
|||
return {
|
||||
"optimal_score": 2**63 - 1,
|
||||
"worst_score": 0,
|
||||
"flags": ScorerFlag.RESULT_I64,
|
||||
"flags": ScorerFlag.RESULT_I64 | ScorerFlag.SYMMETRIC,
|
||||
}
|
||||
|
||||
|
||||
def _get_scorer_flags_normalized_distance(**_kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"optimal_score": 0,
|
||||
"worst_score": 1,
|
||||
"flags": ScorerFlag.RESULT_F64 | ScorerFlag.SYMMETRIC,
|
||||
}
|
||||
|
||||
|
||||
def _get_scorer_flags_normalized_similarity(**_kwargs: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"optimal_score": 1,
|
||||
"worst_score": 0,
|
||||
"flags": ScorerFlag.RESULT_F64 | ScorerFlag.SYMMETRIC,
|
||||
}
|
||||
|
||||
|
||||
|
@ -40,14 +56,6 @@ def is_none(s: Any) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _get_scorer_flags_normalized_distance(**_kwargs: Any) -> dict[str, Any]:
|
||||
return {"optimal_score": 0, "worst_score": 1, "flags": ScorerFlag.RESULT_F64}
|
||||
|
||||
|
||||
def _get_scorer_flags_normalized_similarity(**_kwargs: Any) -> dict[str, Any]:
|
||||
return {"optimal_score": 1, "worst_score": 0, "flags": ScorerFlag.RESULT_F64}
|
||||
|
||||
|
||||
def _create_scorer(
|
||||
func: Any, cached_scorer_call: dict[str, Callable[..., dict[str, Any]]]
|
||||
):
|
||||
|
|
|
@ -9,12 +9,71 @@ substitutions required to transform s1 into s2.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from rapidfuzz._utils import default_distance_attribute as _dist_attr
|
||||
from rapidfuzz._utils import default_normalized_distance_attribute as _norm_dist_attr
|
||||
from rapidfuzz._utils import default_normalized_similarity_attribute as _norm_sim_attr
|
||||
from rapidfuzz._utils import default_similarity_attribute as _sim_attr
|
||||
from rapidfuzz._utils import ScorerFlag as _ScorerFlag
|
||||
from rapidfuzz._utils import fallback_import as _fallback_import
|
||||
|
||||
|
||||
def _get_scorer_flags_distance(
|
||||
weights: tuple[int, int, int] | None = (1, 1, 1)
|
||||
) -> dict[str, Any]:
|
||||
flags = _ScorerFlag.RESULT_I64
|
||||
if weights is None or weights[0] == weights[1]:
|
||||
flags |= _ScorerFlag.SYMMETRIC
|
||||
|
||||
return {
|
||||
"optimal_score": 0,
|
||||
"worst_score": 2**63 - 1,
|
||||
"flags": flags,
|
||||
}
|
||||
|
||||
|
||||
def _get_scorer_flags_similarity(
|
||||
weights: tuple[int, int, int] | None = (1, 1, 1)
|
||||
) -> dict[str, Any]:
|
||||
flags = _ScorerFlag.RESULT_I64
|
||||
if weights is None or weights[0] == weights[1]:
|
||||
flags |= _ScorerFlag.SYMMETRIC
|
||||
|
||||
return {
|
||||
"optimal_score": 2**63 - 1,
|
||||
"worst_score": 0,
|
||||
"flags": flags,
|
||||
}
|
||||
|
||||
|
||||
def _get_scorer_flags_normalized_distance(
|
||||
weights: tuple[int, int, int] | None = (1, 1, 1)
|
||||
) -> dict[str, Any]:
|
||||
flags = _ScorerFlag.RESULT_F64
|
||||
if weights is None or weights[0] == weights[1]:
|
||||
flags |= _ScorerFlag.SYMMETRIC
|
||||
|
||||
return {"optimal_score": 0, "worst_score": 1, "flags": flags}
|
||||
|
||||
|
||||
def _get_scorer_flags_normalized_similarity(
|
||||
weights: tuple[int, int, int] | None = (1, 1, 1)
|
||||
) -> dict[str, Any]:
|
||||
flags = _ScorerFlag.RESULT_F64
|
||||
if weights is None or weights[0] == weights[1]:
|
||||
flags |= _ScorerFlag.SYMMETRIC
|
||||
|
||||
return {"optimal_score": 1, "worst_score": 0, "flags": flags}
|
||||
|
||||
|
||||
_dist_attr: dict[str, Callable[..., dict[str, Any]]] = {
|
||||
"get_scorer_flags": _get_scorer_flags_distance
|
||||
}
|
||||
_sim_attr: dict[str, Callable[..., dict[str, Any]]] = {
|
||||
"get_scorer_flags": _get_scorer_flags_similarity
|
||||
}
|
||||
_norm_dist_attr: dict[str, Callable[..., dict[str, Any]]] = {
|
||||
"get_scorer_flags": _get_scorer_flags_normalized_distance
|
||||
}
|
||||
_norm_sim_attr: dict[str, Callable[..., dict[str, Any]]] = {
|
||||
"get_scorer_flags": _get_scorer_flags_normalized_similarity
|
||||
}
|
||||
|
||||
_mod = "rapidfuzz.distance.Levenshtein"
|
||||
distance = _fallback_import(_mod, "distance", cached_scorer_call=_dist_attr)
|
||||
similarity = _fallback_import(_mod, "similarity", cached_scorer_call=_sim_attr)
|
||||
|
|
|
@ -10,7 +10,11 @@ from rapidfuzz._utils import fallback_import as _fallback_import
|
|||
|
||||
|
||||
def _get_scorer_flags_fuzz(**_kwargs: Any) -> dict[str, Any]:
|
||||
return {"optimal_score": 100, "worst_score": 0, "flags": _ScorerFlag.RESULT_F64}
|
||||
return {
|
||||
"optimal_score": 100,
|
||||
"worst_score": 0,
|
||||
"flags": _ScorerFlag.RESULT_F64 | _ScorerFlag.SYMMETRIC,
|
||||
}
|
||||
|
||||
|
||||
_fuzz_attribute: dict[str, Callable[..., dict[str, Any]]] = {
|
||||
|
|
|
@ -221,11 +221,6 @@ cdef bool GetScorerFlagsFuzzRatio(const RF_Kwargs* self, RF_ScorerFlags* scorer_
|
|||
scorer_flags.worst_score.f64 = 0
|
||||
return True
|
||||
|
||||
def _GetScorerFlagsSimilarity(**kwargs):
|
||||
return {"optimal_score": 100, "worst_score": 0, "flags": (1 << 5)}
|
||||
|
||||
cdef dict FuzzContextPy = CreateScorerContextPy(_GetScorerFlagsSimilarity)
|
||||
|
||||
cdef RF_Scorer RatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, RatioInit)
|
||||
ratio._RF_Scorer = PyCapsule_New(&RatioContext, NULL, NULL)
|
||||
|
||||
|
|
|
@ -1470,6 +1470,8 @@ cdef Matrix cdist_single_list(
|
|||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
cdef cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers, dict kwargs):
|
||||
# todo this should handle two similar sequences more efficiently
|
||||
|
||||
proc_queries = preprocess_py(queries, processor)
|
||||
proc_choices = preprocess_py(choices, processor)
|
||||
cdef double score
|
||||
|
|
|
@ -520,8 +520,10 @@ def extract(
|
|||
return heapq.nsmallest(limit, result_iter, key=lambda i: i[1])
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
try:
|
||||
import numpy as np
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def _dtype_to_type_num(
|
||||
|
@ -544,6 +546,16 @@ def _dtype_to_type_num(
|
|||
return np.float32
|
||||
|
||||
|
||||
def _is_symmetric(scorer: Callable[..., int | float], **kwargs: dict[str, Any]) -> bool:
|
||||
params = getattr(scorer, "_RF_ScorerPy", None)
|
||||
if params is not None:
|
||||
flags = params["get_scorer_flags"](**kwargs)
|
||||
if flags["flags"] & ScorerFlag.SYMMETRIC:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def cdist(
|
||||
queries: Collection[Sequence[Hashable] | None],
|
||||
choices: Collection[Sequence[Hashable] | None],
|
||||
|
@ -616,7 +628,7 @@ def cdist(
|
|||
dtype = _dtype_to_type_num(dtype, scorer, **kwargs)
|
||||
results = np.zeros((len(queries), len(choices)), dtype=dtype)
|
||||
|
||||
if queries is choices:
|
||||
if queries is choices and _is_symmetric(scorer, **kwargs):
|
||||
if processor is None:
|
||||
proc_queries = list(queries)
|
||||
else:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import pytest
|
||||
|
||||
from rapidfuzz import fuzz, process_cpp, process_py
|
||||
from rapidfuzz.distance import Levenshtein
|
||||
|
||||
|
||||
def wrapped(func):
|
||||
|
@ -39,12 +40,14 @@ class process:
|
|||
|
||||
@staticmethod
|
||||
def cdist(*args, **kwargs):
|
||||
import numpy as np
|
||||
|
||||
res1 = process_cpp.cdist(*args, **kwargs)
|
||||
res2 = process_py.cdist(*args, **kwargs)
|
||||
assert res1.dtype == res2.dtype
|
||||
assert res1.shape == res2.shape
|
||||
if res1.size and res2.size:
|
||||
assert res1 == res2
|
||||
assert np.array_equal(res1, res2)
|
||||
return res1
|
||||
|
||||
|
||||
|
@ -382,3 +385,15 @@ def test_wrapped_function(scorer):
|
|||
assert process.cdist(["test"], [float("nan")], scorer=scorer)[0, 0] == 100
|
||||
assert process.cdist(["test"], [None], scorer=scorer)[0, 0] == 100
|
||||
assert process.cdist(["test"], ["tes"], scorer=scorer)[0, 0] == 100
|
||||
|
||||
|
||||
def test_cdist_not_symmetric():
|
||||
pytest.importorskip("numpy")
|
||||
import numpy as np
|
||||
|
||||
strings = ["test", "test2"]
|
||||
expected_res = np.array([[0, 1], [2, 0]])
|
||||
assert np.array_equal(
|
||||
process.cdist(strings, strings, scorer=Levenshtein.distance, weights=(1, 2, 1)),
|
||||
expected_res,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue