Compare commits
2 Commits
e6bc56c1ad
...
3aeef594a1
Author | SHA1 | Date |
---|---|---|
Max Bachmann | 3aeef594a1 | |
Max Bachmann | 9693b7da76 |
|
@ -1,5 +1,9 @@
|
|||
## Changelog
|
||||
|
||||
### [2.14.0] -
|
||||
#### Fixed
|
||||
- improve handling of functions wrapped using `functools.wraps`
|
||||
|
||||
### [2.13.2] - 2022-11-05
|
||||
#### Fixed
|
||||
- fix incorrect results in `Hamming.normalized_similarity`
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 97cb88437af19ebb095d46cf545cbce23ec2d083
|
||||
Subproject commit 749d32ad560d5d9c9917dec61f5d28c2b0923a78
|
|
@ -36,6 +36,11 @@ def _get_scorer_flags_normalized_distance(**_kwargs: Any) -> dict[str, Any]:
|
|||
def _get_scorer_flags_normalized_similarity(**_kwargs: Any) -> dict[str, Any]:
|
||||
return {"optimal_score": 1, "worst_score": 0, "flags": ScorerFlag.RESULT_F64}
|
||||
|
||||
def _create_scorer(func: Any, cached_scorer_call: dict[str, Callable[..., dict[str, Any]]]):
|
||||
func._RF_ScorerPy = cached_scorer_call
|
||||
# used to detect the function hasn't been wrapped afterwards
|
||||
func._RF_OriginalScorer = func
|
||||
return func
|
||||
|
||||
def fallback_import(
|
||||
module: str,
|
||||
|
@ -60,7 +65,7 @@ def fallback_import(
|
|||
)
|
||||
|
||||
if cached_scorer_call:
|
||||
py_func._RF_ScorerPy = cached_scorer_call
|
||||
py_func = _create_scorer(py_func, cached_scorer_call)
|
||||
|
||||
if impl == "cpp":
|
||||
cpp_mod = importlib.import_module(module + "_cpp")
|
||||
|
@ -84,11 +89,10 @@ def fallback_import(
|
|||
cpp_func.__doc__ = py_func.__doc__
|
||||
|
||||
if cached_scorer_call:
|
||||
cpp_func._RF_ScorerPy = cached_scorer_call
|
||||
cpp_func = _create_scorer(cpp_func, cached_scorer_call)
|
||||
|
||||
return cpp_func
|
||||
|
||||
|
||||
default_distance_attribute: dict[str, Callable[..., dict[str, Any]]] = {
|
||||
"get_scorer_flags": _get_scorer_flags_distance
|
||||
}
|
||||
|
|
|
@ -409,8 +409,3 @@ cdef inline RF_Scorer CreateScorerContext(RF_KwargsInit kwargs_init, RF_GetScore
|
|||
|
||||
cdef inline dict CreateScorerContextPy(get_scorer_flags):
|
||||
return {"get_scorer_flags": get_scorer_flags}
|
||||
|
||||
cdef inline bool AddScorerContext(func, py_context, RF_Scorer* c_context) except False:
|
||||
func._RF_Scorer = PyCapsule_New(c_context, NULL, NULL)
|
||||
func._RF_ScorerPy = py_context
|
||||
return True
|
||||
|
|
|
@ -20,8 +20,9 @@ from array import array
|
|||
|
||||
from rapidfuzz.utils import default_process
|
||||
|
||||
from cpython.pycapsule cimport PyCapsule_New
|
||||
|
||||
from cpp_common cimport (
|
||||
AddScorerContext,
|
||||
CreateScorerContext,
|
||||
CreateScorerContextPy,
|
||||
NoKwargsInit,
|
||||
|
@ -217,31 +218,31 @@ def _GetScorerFlagsSimilarity(**kwargs):
|
|||
cdef dict FuzzContextPy = CreateScorerContextPy(_GetScorerFlagsSimilarity)
|
||||
|
||||
cdef RF_Scorer RatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, RatioInit)
|
||||
AddScorerContext(ratio, FuzzContextPy, &RatioContext)
|
||||
ratio._RF_Scorer = PyCapsule_New(&RatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer PartialRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialRatioInit)
|
||||
AddScorerContext(partial_ratio, FuzzContextPy, &PartialRatioContext)
|
||||
partial_ratio._RF_Scorer = PyCapsule_New(&PartialRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer TokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, TokenSortRatioInit)
|
||||
AddScorerContext(token_sort_ratio, FuzzContextPy, &TokenSortRatioContext)
|
||||
token_sort_ratio._RF_Scorer = PyCapsule_New(&TokenSortRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer TokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenSetRatioInit)
|
||||
AddScorerContext(token_set_ratio, FuzzContextPy, &TokenSetRatioContext)
|
||||
token_set_ratio._RF_Scorer = PyCapsule_New(&TokenSetRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer TokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenRatioInit)
|
||||
AddScorerContext(token_ratio, FuzzContextPy, &TokenRatioContext)
|
||||
token_ratio._RF_Scorer = PyCapsule_New(&TokenRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer PartialTokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSortRatioInit)
|
||||
AddScorerContext(partial_token_sort_ratio, FuzzContextPy, &PartialTokenSortRatioContext)
|
||||
partial_token_sort_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSortRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer PartialTokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSetRatioInit)
|
||||
AddScorerContext(partial_token_set_ratio, FuzzContextPy, &PartialTokenSetRatioContext)
|
||||
partial_token_set_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSetRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer PartialTokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenRatioInit)
|
||||
AddScorerContext(partial_token_ratio, FuzzContextPy, &PartialTokenRatioContext)
|
||||
partial_token_ratio._RF_Scorer = PyCapsule_New(&PartialTokenRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer WRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, WRatioInit)
|
||||
AddScorerContext(WRatio, FuzzContextPy, &WRatioContext)
|
||||
WRatio._RF_Scorer = PyCapsule_New(&WRatioContext, NULL, NULL)
|
||||
|
||||
cdef RF_Scorer QRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, QRatioInit)
|
||||
AddScorerContext(QRatio, FuzzContextPy, &QRatioContext)
|
||||
QRatio._RF_Scorer = PyCapsule_New(&QRatioContext, NULL, NULL)
|
||||
|
|
|
@ -1482,6 +1482,7 @@ cdef cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers,
|
|||
def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None, score_hint=None, dtype=None, workers=1, **kwargs):
|
||||
cdef RF_Scorer* scorer_context = NULL
|
||||
cdef RF_ScorerFlags scorer_flags
|
||||
cdef bool is_orig_scorer
|
||||
|
||||
if processor is True:
|
||||
# todo: deprecate this
|
||||
|
@ -1493,8 +1494,9 @@ def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None,
|
|||
if PyCapsule_IsValid(scorer_capsule, NULL):
|
||||
scorer_context = <RF_Scorer*>PyCapsule_GetPointer(scorer_capsule, NULL)
|
||||
|
||||
if scorer_context:
|
||||
if scorer_context.version == SCORER_STRUCT_VERSION:
|
||||
is_orig_scorer = getattr(scorer, '_RF_OriginalScorer', None) is scorer
|
||||
|
||||
if is_orig_scorer and scorer_context and scorer_context.version == SCORER_STRUCT_VERSION:
|
||||
kwargs_context = RF_KwargsWrapper()
|
||||
scorer_context.kwargs_init(&kwargs_context.kwargs, kwargs)
|
||||
scorer_context.get_scorer_flags(&kwargs_context.kwargs, &scorer_flags)
|
||||
|
|
|
@ -2,6 +2,13 @@ import pytest
|
|||
|
||||
from rapidfuzz import fuzz, process_cpp, process_py
|
||||
|
||||
def wrapped(func):
|
||||
from functools import wraps
|
||||
@wraps(func)
|
||||
def decorator(*args, **kwargs):
|
||||
return 100
|
||||
|
||||
return decorator
|
||||
|
||||
class process:
|
||||
@staticmethod
|
||||
|
@ -27,6 +34,16 @@ class process:
|
|||
assert res1 == res2
|
||||
return res1
|
||||
|
||||
@staticmethod
|
||||
def cdist(*args, **kwargs):
|
||||
res1 = process_cpp.cdist(*args, **kwargs)
|
||||
res2 = process_py.cdist(*args, **kwargs)
|
||||
assert res1.dtype == res2.dtype
|
||||
assert res1.shape == res2.shape
|
||||
if res1.size and res2.size:
|
||||
assert res1 == res2
|
||||
return res1
|
||||
|
||||
|
||||
baseball_strings = [
|
||||
"new york mets vs chicago cubs",
|
||||
|
@ -351,7 +368,14 @@ def test_extractOne_use_first_match(scorer):
|
|||
@pytest.mark.parametrize("scorer", [fuzz.ratio, fuzz.WRatio, custom_scorer])
|
||||
def test_cdist_empty_seq(scorer):
|
||||
pytest.importorskip("numpy")
|
||||
assert process_cpp.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
|
||||
assert process_cpp.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
|
||||
assert process_py.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
|
||||
assert process_py.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
|
||||
assert process.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
|
||||
assert process.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scorer", [fuzz.ratio])
|
||||
def test_wrapped_function(scorer):
|
||||
pytest.importorskip("numpy")
|
||||
scorer = wrapped(scorer)
|
||||
assert process.cdist(["test"], [float("nan")], scorer=scorer)[0, 0] == 100
|
||||
assert process.cdist(["test"], [None], scorer=scorer)[0, 0] == 100
|
||||
assert process.cdist(["test"], ["tes"], scorer=scorer)[0, 0] == 100
|
||||
|
|
Loading…
Reference in New Issue