improve handling of functions wrapped using `functools.wraps`

This commit is contained in:
Max Bachmann 2022-11-29 14:58:30 +01:00
parent e6bc56c1ad
commit 9693b7da76
7 changed files with 62 additions and 35 deletions

View File

@ -1,5 +1,9 @@
## Changelog ## Changelog
### [2.14.0] -
#### Fixed
- improve handling of functions wrapped using `functools.wraps`
### [2.13.2] - 2022-11-05 ### [2.13.2] - 2022-11-05
#### Fixed #### Fixed
- fix incorrect results in `Hamming.normalized_similarity` - fix incorrect results in `Hamming.normalized_similarity`

@ -1 +1 @@
Subproject commit 97cb88437af19ebb095d46cf545cbce23ec2d083 Subproject commit 749d32ad560d5d9c9917dec61f5d28c2b0923a78

View File

@ -61,6 +61,8 @@ def fallback_import(
if cached_scorer_call: if cached_scorer_call:
py_func._RF_ScorerPy = cached_scorer_call py_func._RF_ScorerPy = cached_scorer_call
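# Self-reference used to detect whether the function has been wrapped afterwards:
# `functools.wraps` copies this attribute onto the wrapper, where it still points
# at the original function instead of at the wrapper itself.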
py_func._RF_OriginalScorer = py_func
if impl == "cpp": if impl == "cpp":
cpp_mod = importlib.import_module(module + "_cpp") cpp_mod = importlib.import_module(module + "_cpp")
@ -85,6 +87,8 @@ def fallback_import(
if cached_scorer_call: if cached_scorer_call:
cpp_func._RF_ScorerPy = cached_scorer_call cpp_func._RF_ScorerPy = cached_scorer_call
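# Self-reference used to detect whether the function has been wrapped afterwards:
# `functools.wraps` copies this attribute onto the wrapper, where it still points
# at the original function instead of at the wrapper itself.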
cpp_func._RF_OriginalScorer = cpp_func
return cpp_func return cpp_func

View File

@ -409,8 +409,3 @@ cdef inline RF_Scorer CreateScorerContext(RF_KwargsInit kwargs_init, RF_GetScore
cdef inline dict CreateScorerContextPy(get_scorer_flags): cdef inline dict CreateScorerContextPy(get_scorer_flags):
return {"get_scorer_flags": get_scorer_flags} return {"get_scorer_flags": get_scorer_flags}
cdef inline bool AddScorerContext(func, py_context, RF_Scorer* c_context) except False:
func._RF_Scorer = PyCapsule_New(c_context, NULL, NULL)
func._RF_ScorerPy = py_context
return True

View File

@ -20,8 +20,9 @@ from array import array
from rapidfuzz.utils import default_process from rapidfuzz.utils import default_process
from cpython.pycapsule cimport PyCapsule_New
from cpp_common cimport ( from cpp_common cimport (
AddScorerContext,
CreateScorerContext, CreateScorerContext,
CreateScorerContextPy, CreateScorerContextPy,
NoKwargsInit, NoKwargsInit,
@ -217,31 +218,31 @@ def _GetScorerFlagsSimilarity(**kwargs):
cdef dict FuzzContextPy = CreateScorerContextPy(_GetScorerFlagsSimilarity) cdef dict FuzzContextPy = CreateScorerContextPy(_GetScorerFlagsSimilarity)
cdef RF_Scorer RatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, RatioInit) cdef RF_Scorer RatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, RatioInit)
AddScorerContext(ratio, FuzzContextPy, &RatioContext) ratio._RF_Scorer = PyCapsule_New(&RatioContext, NULL, NULL)
cdef RF_Scorer PartialRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialRatioInit) cdef RF_Scorer PartialRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialRatioInit)
AddScorerContext(partial_ratio, FuzzContextPy, &PartialRatioContext) partial_ratio._RF_Scorer = PyCapsule_New(&PartialRatioContext, NULL, NULL)
cdef RF_Scorer TokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, TokenSortRatioInit) cdef RF_Scorer TokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, TokenSortRatioInit)
AddScorerContext(token_sort_ratio, FuzzContextPy, &TokenSortRatioContext) token_sort_ratio._RF_Scorer = PyCapsule_New(&TokenSortRatioContext, NULL, NULL)
cdef RF_Scorer TokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenSetRatioInit) cdef RF_Scorer TokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenSetRatioInit)
AddScorerContext(token_set_ratio, FuzzContextPy, &TokenSetRatioContext) token_set_ratio._RF_Scorer = PyCapsule_New(&TokenSetRatioContext, NULL, NULL)
cdef RF_Scorer TokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenRatioInit) cdef RF_Scorer TokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenRatioInit)
AddScorerContext(token_ratio, FuzzContextPy, &TokenRatioContext) token_ratio._RF_Scorer = PyCapsule_New(&TokenRatioContext, NULL, NULL)
cdef RF_Scorer PartialTokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSortRatioInit) cdef RF_Scorer PartialTokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSortRatioInit)
AddScorerContext(partial_token_sort_ratio, FuzzContextPy, &PartialTokenSortRatioContext) partial_token_sort_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSortRatioContext, NULL, NULL)
cdef RF_Scorer PartialTokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSetRatioInit) cdef RF_Scorer PartialTokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSetRatioInit)
AddScorerContext(partial_token_set_ratio, FuzzContextPy, &PartialTokenSetRatioContext) partial_token_set_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSetRatioContext, NULL, NULL)
cdef RF_Scorer PartialTokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenRatioInit) cdef RF_Scorer PartialTokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenRatioInit)
AddScorerContext(partial_token_ratio, FuzzContextPy, &PartialTokenRatioContext) partial_token_ratio._RF_Scorer = PyCapsule_New(&PartialTokenRatioContext, NULL, NULL)
cdef RF_Scorer WRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, WRatioInit) cdef RF_Scorer WRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, WRatioInit)
AddScorerContext(WRatio, FuzzContextPy, &WRatioContext) WRatio._RF_Scorer = PyCapsule_New(&WRatioContext, NULL, NULL)
cdef RF_Scorer QRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, QRatioInit) cdef RF_Scorer QRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, QRatioInit)
AddScorerContext(QRatio, FuzzContextPy, &QRatioContext) QRatio._RF_Scorer = PyCapsule_New(&QRatioContext, NULL, NULL)

View File

@ -1482,6 +1482,7 @@ cdef cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers,
def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None, score_hint=None, dtype=None, workers=1, **kwargs): def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None, score_hint=None, dtype=None, workers=1, **kwargs):
cdef RF_Scorer* scorer_context = NULL cdef RF_Scorer* scorer_context = NULL
cdef RF_ScorerFlags scorer_flags cdef RF_ScorerFlags scorer_flags
cdef bool is_orig_scorer
if processor is True: if processor is True:
# todo: deprecate this # todo: deprecate this
@ -1493,20 +1494,21 @@ def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None,
if PyCapsule_IsValid(scorer_capsule, NULL): if PyCapsule_IsValid(scorer_capsule, NULL):
scorer_context = <RF_Scorer*>PyCapsule_GetPointer(scorer_capsule, NULL) scorer_context = <RF_Scorer*>PyCapsule_GetPointer(scorer_capsule, NULL)
if scorer_context: is_orig_scorer = getattr(scorer, '_RF_OriginalScorer', None) is scorer
if scorer_context.version == SCORER_STRUCT_VERSION:
kwargs_context = RF_KwargsWrapper()
scorer_context.kwargs_init(&kwargs_context.kwargs, kwargs)
scorer_context.get_scorer_flags(&kwargs_context.kwargs, &scorer_flags)
# scorer(queries[i], choices[j]) == scorer(queries[j], choices[i]) if is_orig_scorer and scorer_context and scorer_context.version == SCORER_STRUCT_VERSION:
if scorer_flags.flags & RF_SCORER_FLAG_SYMMETRIC and queries is choices: kwargs_context = RF_KwargsWrapper()
return cdist_single_list( scorer_context.kwargs_init(&kwargs_context.kwargs, kwargs)
queries, scorer_context, &scorer_flags, processor, scorer_context.get_scorer_flags(&kwargs_context.kwargs, &scorer_flags)
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs)
else: # scorer(queries[i], choices[j]) == scorer(queries[j], choices[i])
return cdist_two_lists( if scorer_flags.flags & RF_SCORER_FLAG_SYMMETRIC and queries is choices:
queries, choices, scorer_context, &scorer_flags, processor, return cdist_single_list(
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs) queries, scorer_context, &scorer_flags, processor,
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs)
else:
return cdist_two_lists(
queries, choices, scorer_context, &scorer_flags, processor,
score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs)
return cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers, kwargs) return cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers, kwargs)

View File

@ -2,6 +2,13 @@ import pytest
from rapidfuzz import fuzz, process_cpp, process_py from rapidfuzz import fuzz, process_cpp, process_py
def wrapped(func):
    """Return a stand-in for *func* that ignores its arguments and returns 100.

    The stand-in is created with ``functools.wraps``, so it carries the
    metadata of *func* (``__name__``, ``__doc__``, the entries of
    ``__dict__``, ``__wrapped__``) while being a distinct callable with
    different behaviour.

    Args:
        func: The callable whose metadata is copied onto the stand-in.

    Returns:
        A callable accepting any positional/keyword arguments and always
        returning ``100``.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Constant result: makes it observable whether the wrapper itself
        # (rather than the function it wraps) was actually invoked.
        return 100

    return wrapper
class process: class process:
@staticmethod @staticmethod
@ -27,6 +34,13 @@ class process:
assert res1 == res2 assert res1 == res2
return res1 return res1
@staticmethod
def cdist(*args, **kwargs):
    """Run ``cdist`` through both implementations and cross-check them.

    All arguments are forwarded unchanged to ``process_cpp.cdist`` and
    ``process_py.cdist``.

    Returns:
        The result of ``process_cpp.cdist``.

    Raises:
        AssertionError: If the two implementations disagree.
    """
    # Imported lazily: the tests calling this helper guard themselves with
    # ``pytest.importorskip("numpy")``, so numpy may be absent at import time.
    import numpy as np

    res1 = process_cpp.cdist(*args, **kwargs)
    res2 = process_py.cdist(*args, **kwargs)
    # The results are used as 2-D arrays by the tests (``.shape``, ``[0, 0]``)
    # -- presumably numpy arrays.  ``res1 == res2`` would then be an
    # element-wise comparison whose truth value is ambiguous for more than one
    # element and false/an error for empty results, so compare as a whole.
    assert np.array_equal(res1, res2)
    return res1
baseball_strings = [ baseball_strings = [
"new york mets vs chicago cubs", "new york mets vs chicago cubs",
@ -351,7 +365,14 @@ def test_extractOne_use_first_match(scorer):
@pytest.mark.parametrize("scorer", [fuzz.ratio, fuzz.WRatio, custom_scorer]) @pytest.mark.parametrize("scorer", [fuzz.ratio, fuzz.WRatio, custom_scorer])
def test_cdist_empty_seq(scorer): def test_cdist_empty_seq(scorer):
pytest.importorskip("numpy") pytest.importorskip("numpy")
assert process_cpp.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2) assert process.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
assert process_cpp.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0) assert process.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
assert process_py.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2)
assert process_py.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0)
@pytest.mark.parametrize("scorer", [fuzz.ratio])
def test_wrapped_function(scorer):
    """``cdist`` must call a ``functools.wraps`` wrapper, not the scorer it wraps."""
    pytest.importorskip("numpy")
    wrapped_scorer = wrapped(scorer)
    # ``wrapped`` returns 100 for every input, including NaN and None; any
    # other score means the wrapper was bypassed in favour of the wrapped
    # scorer.
    for choice in (float("nan"), None, "tes"):
        assert process.cdist(["test"], [choice], scorer=wrapped_scorer)[0, 0] == 100