improve handling of functions wrapped using `functools.wraps`
This commit is contained in:
parent
e6bc56c1ad
commit
9693b7da76
|
@ -1,5 +1,9 @@
|
||||||
## Changelog
|
## Changelog
|
||||||
|
|
||||||
|
### [2.14.0] -
|
||||||
|
#### Fixed
|
||||||
|
- improve handling of functions wrapped using `functools.wraps`
|
||||||
|
|
||||||
### [2.13.2] - 2022-11-05
|
### [2.13.2] - 2022-11-05
|
||||||
#### Fixed
|
#### Fixed
|
||||||
- fix incorrect results in `Hamming.normalized_similarity`
|
- fix incorrect results in `Hamming.normalized_similarity`
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 97cb88437af19ebb095d46cf545cbce23ec2d083
|
Subproject commit 749d32ad560d5d9c9917dec61f5d28c2b0923a78
|
|
@ -61,6 +61,8 @@ def fallback_import(
|
||||||
|
|
||||||
if cached_scorer_call:
|
if cached_scorer_call:
|
||||||
py_func._RF_ScorerPy = cached_scorer_call
|
py_func._RF_ScorerPy = cached_scorer_call
|
||||||
|
# used to detect the function hasn't been wrapped afterwards
|
||||||
|
py_func._RF_OriginalScorer = py_func
|
||||||
|
|
||||||
if impl == "cpp":
|
if impl == "cpp":
|
||||||
cpp_mod = importlib.import_module(module + "_cpp")
|
cpp_mod = importlib.import_module(module + "_cpp")
|
||||||
|
@ -85,6 +87,8 @@ def fallback_import(
|
||||||
|
|
||||||
if cached_scorer_call:
|
if cached_scorer_call:
|
||||||
cpp_func._RF_ScorerPy = cached_scorer_call
|
cpp_func._RF_ScorerPy = cached_scorer_call
|
||||||
|
# used to detect the function hasn't been wrapped afterwards
|
||||||
|
cpp_func._RF_OriginalScorer = cpp_func
|
||||||
|
|
||||||
return cpp_func
|
return cpp_func
|
||||||
|
|
||||||
|
|
|
@ -409,8 +409,3 @@ cdef inline RF_Scorer CreateScorerContext(RF_KwargsInit kwargs_init, RF_GetScore
|
||||||
|
|
||||||
cdef inline dict CreateScorerContextPy(get_scorer_flags):
|
cdef inline dict CreateScorerContextPy(get_scorer_flags):
|
||||||
return {"get_scorer_flags": get_scorer_flags}
|
return {"get_scorer_flags": get_scorer_flags}
|
||||||
|
|
||||||
cdef inline bool AddScorerContext(func, py_context, RF_Scorer* c_context) except False:
|
|
||||||
func._RF_Scorer = PyCapsule_New(c_context, NULL, NULL)
|
|
||||||
func._RF_ScorerPy = py_context
|
|
||||||
return True
|
|
||||||
|
|
|
@ -20,8 +20,9 @@ from array import array
|
||||||
|
|
||||||
from rapidfuzz.utils import default_process
|
from rapidfuzz.utils import default_process
|
||||||
|
|
||||||
|
from cpython.pycapsule cimport PyCapsule_New
|
||||||
|
|
||||||
from cpp_common cimport (
|
from cpp_common cimport (
|
||||||
AddScorerContext,
|
|
||||||
CreateScorerContext,
|
CreateScorerContext,
|
||||||
CreateScorerContextPy,
|
CreateScorerContextPy,
|
||||||
NoKwargsInit,
|
NoKwargsInit,
|
||||||
|
@ -217,31 +218,31 @@ def _GetScorerFlagsSimilarity(**kwargs):
|
||||||
cdef dict FuzzContextPy = CreateScorerContextPy(_GetScorerFlagsSimilarity)
|
cdef dict FuzzContextPy = CreateScorerContextPy(_GetScorerFlagsSimilarity)
|
||||||
|
|
||||||
cdef RF_Scorer RatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, RatioInit)
|
cdef RF_Scorer RatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, RatioInit)
|
||||||
AddScorerContext(ratio, FuzzContextPy, &RatioContext)
|
ratio._RF_Scorer = PyCapsule_New(&RatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer PartialRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialRatioInit)
|
cdef RF_Scorer PartialRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialRatioInit)
|
||||||
AddScorerContext(partial_ratio, FuzzContextPy, &PartialRatioContext)
|
partial_ratio._RF_Scorer = PyCapsule_New(&PartialRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer TokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, TokenSortRatioInit)
|
cdef RF_Scorer TokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, TokenSortRatioInit)
|
||||||
AddScorerContext(token_sort_ratio, FuzzContextPy, &TokenSortRatioContext)
|
token_sort_ratio._RF_Scorer = PyCapsule_New(&TokenSortRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer TokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenSetRatioInit)
|
cdef RF_Scorer TokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenSetRatioInit)
|
||||||
AddScorerContext(token_set_ratio, FuzzContextPy, &TokenSetRatioContext)
|
token_set_ratio._RF_Scorer = PyCapsule_New(&TokenSetRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer TokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenRatioInit)
|
cdef RF_Scorer TokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenRatioInit)
|
||||||
AddScorerContext(token_ratio, FuzzContextPy, &TokenRatioContext)
|
token_ratio._RF_Scorer = PyCapsule_New(&TokenRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer PartialTokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSortRatioInit)
|
cdef RF_Scorer PartialTokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSortRatioInit)
|
||||||
AddScorerContext(partial_token_sort_ratio, FuzzContextPy, &PartialTokenSortRatioContext)
|
partial_token_sort_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSortRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer PartialTokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSetRatioInit)
|
cdef RF_Scorer PartialTokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSetRatioInit)
|
||||||
AddScorerContext(partial_token_set_ratio, FuzzContextPy, &PartialTokenSetRatioContext)
|
partial_token_set_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSetRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer PartialTokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenRatioInit)
|
cdef RF_Scorer PartialTokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenRatioInit)
|
||||||
AddScorerContext(partial_token_ratio, FuzzContextPy, &PartialTokenRatioContext)
|
partial_token_ratio._RF_Scorer = PyCapsule_New(&PartialTokenRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer WRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, WRatioInit)
|
cdef RF_Scorer WRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, WRatioInit)
|
||||||
AddScorerContext(WRatio, FuzzContextPy, &WRatioContext)
|
WRatio._RF_Scorer = PyCapsule_New(&WRatioContext, NULL, NULL)
|
||||||
|
|
||||||
cdef RF_Scorer QRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, QRatioInit)
|
cdef RF_Scorer QRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, QRatioInit)
|
||||||
AddScorerContext(QRatio, FuzzContextPy, &QRatioContext)
|
QRatio._RF_Scorer = PyCapsule_New(&QRatioContext, NULL, NULL)
|
||||||
|
|
|
@ -1482,6 +1482,7 @@ cdef cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers,
|
||||||
def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None, score_hint=None, dtype=None, workers=1, **kwargs):
|
def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None, score_hint=None, dtype=None, workers=1, **kwargs):
|
||||||
cdef RF_Scorer* scorer_context = NULL
|
cdef RF_Scorer* scorer_context = NULL
|
||||||
cdef RF_ScorerFlags scorer_flags
|
cdef RF_ScorerFlags scorer_flags
|
||||||
|
cdef bool is_orig_scorer
|
||||||
|
|
||||||
if processor is True:
|
if processor is True:
|
||||||
# todo: deprecate this
|
# todo: deprecate this
|
||||||
|
@ -1493,20 +1494,21 @@ def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None,
|
||||||
if PyCapsule_IsValid(scorer_capsule, NULL):
|
if PyCapsule_IsValid(scorer_capsule, NULL):
|
||||||
scorer_context = <RF_Scorer*>PyCapsule_GetPointer(scorer_capsule, NULL)
|
scorer_context = <RF_Scorer*>PyCapsule_GetPointer(scorer_capsule, NULL)
|
||||||
|
|
||||||
if scorer_context:
|
is_orig_scorer = getattr(scorer, '_RF_OriginalScorer', None) is scorer
|
||||||
if scorer_context.version == SCORER_STRUCT_VERSION:
|
|
||||||
kwargs_context = RF_KwargsWrapper()
|
|
||||||
scorer_context.kwargs_init(&kwargs_context.kwargs, kwargs)
|
|
||||||
scorer_context.get_scorer_flags(&kwargs_context.kwargs, &scorer_flags)
|
|
||||||
|
|
||||||
# scorer(queries[i], choices[j]) == scorer(queries[j], choices[i])
|
if is_orig_scorer and scorer_context and scorer_context.version == SCORER_STRUCT_VERSION:
|
||||||
if scorer_flags.flags & RF_SCORER_FLAG_SYMMETRIC and queries is choices:
|
kwargs_context = RF_KwargsWrapper()
|
||||||
return cdist_single_list(
|
scorer_context.kwargs_init(&kwargs_context.kwargs, kwargs)
|
||||||
queries, scorer_context, &scorer_flags, processor,
|
scorer_context.get_scorer_flags(&kwargs_context.kwargs, &scorer_flags)
|
||||||
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs)
|
|
||||||
else:
|
# scorer(queries[i], choices[j]) == scorer(queries[j], choices[i])
|
||||||
return cdist_two_lists(
|
if scorer_flags.flags & RF_SCORER_FLAG_SYMMETRIC and queries is choices:
|
||||||
queries, choices, scorer_context, &scorer_flags, processor,
|
return cdist_single_list(
|
||||||
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs)
|
queries, scorer_context, &scorer_flags, processor,
|
||||||
|
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs)
|
||||||
|
else:
|
||||||
|
return cdist_two_lists(
|
||||||
|
queries, choices, scorer_context, &scorer_flags, processor,
|
||||||
|
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs)
|
||||||
|
|
||||||
return cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers, kwargs)
|
return cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers, kwargs)
|
||||||
|
|
|
@ -2,6 +2,13 @@ import pytest
|
||||||
|
|
||||||
from rapidfuzz import fuzz, process_cpp, process_py
|
from rapidfuzz import fuzz, process_cpp, process_py
|
||||||
|
|
||||||
|
def wrapped(func):
|
||||||
|
from functools import wraps
|
||||||
|
@wraps(func)
|
||||||
|
def decorator(*args, **kwargs):
|
||||||
|
return 100
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
class process:
|
class process:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -27,6 +34,13 @@ class process:
|
||||||
assert res1 == res2
|
assert res1 == res2
|
||||||
return res1
|
return res1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cdist(*args, **kwargs):
|
||||||
|
res1 = process_cpp.cdist(*args, **kwargs)
|
||||||
|
res2 = process_py.cdist(*args, **kwargs)
|
||||||
|
assert res1 == res2
|
||||||
|
return res1
|
||||||
|
|
||||||
|
|
||||||
baseball_strings = [
|
baseball_strings = [
|
||||||
"new york mets vs chicago cubs",
|
"new york mets vs chicago cubs",
|
||||||
|
@ -351,7 +365,14 @@ def test_extractOne_use_first_match(scorer):
|
||||||
@pytest.mark.parametrize("scorer", [fuzz.ratio, fuzz.WRatio, custom_scorer])
|
@pytest.mark.parametrize("scorer", [fuzz.ratio, fuzz.WRatio, custom_scorer])
|
||||||
def test_cdist_empty_seq(scorer):
|
def test_cdist_empty_seq(scorer):
|
||||||
pytest.importorskip("numpy")
|
pytest.importorskip("numpy")
|
||||||
assert process_cpp.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
|
assert process.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
|
||||||
assert process_cpp.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
|
assert process.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
|
||||||
assert process_py.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
|
|
||||||
assert process_py.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
|
|
||||||
|
@pytest.mark.parametrize("scorer", [fuzz.ratio])
|
||||||
|
def test_wrapped_function(scorer):
|
||||||
|
pytest.importorskip("numpy")
|
||||||
|
scorer = wrapped(scorer)
|
||||||
|
assert process.cdist(["test"], [float("nan")], scorer=scorer)[0, 0] == 100
|
||||||
|
assert process.cdist(["test"], [None], scorer=scorer)[0, 0] == 100
|
||||||
|
assert process.cdist(["test"], ["tes"], scorer=scorer)[0, 0] == 100
|
||||||
|
|
Loading…
Reference in New Issue