From 9693b7da76784548d39f66dd1d5d7787b266d4cb Mon Sep 17 00:00:00 2001 From: Max Bachmann Date: Tue, 29 Nov 2022 14:58:30 +0100 Subject: [PATCH] improve handling of functions wrapped using `functools.wraps` --- CHANGELOG.md | 4 ++++ extern/rapidfuzz-cpp | 2 +- src/rapidfuzz/_utils.py | 4 ++++ src/rapidfuzz/cpp_common.pxd | 5 ----- src/rapidfuzz/fuzz_cpp.pyx | 23 ++++++++++++----------- src/rapidfuzz/process_cpp_impl.pyx | 30 ++++++++++++++++-------------- tests/test_process.py | 29 +++++++++++++++++++++++++---- 7 files changed, 62 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c32fbbd..3491f26 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ ## Changelog +### [2.14.0] - +#### Fixed +- improve handling of functions wrapped using `functools.wraps` + ### [2.13.2] - 2022-11-05 #### Fixed - fix incorrect results in `Hamming.normalized_similarity` diff --git a/extern/rapidfuzz-cpp b/extern/rapidfuzz-cpp index 97cb884..749d32a 160000 --- a/extern/rapidfuzz-cpp +++ b/extern/rapidfuzz-cpp @@ -1 +1 @@ -Subproject commit 97cb88437af19ebb095d46cf545cbce23ec2d083 +Subproject commit 749d32ad560d5d9c9917dec61f5d28c2b0923a78 diff --git a/src/rapidfuzz/_utils.py b/src/rapidfuzz/_utils.py index 1c176b1..b9d2813 100644 --- a/src/rapidfuzz/_utils.py +++ b/src/rapidfuzz/_utils.py @@ -61,6 +61,8 @@ def fallback_import( if cached_scorer_call: py_func._RF_ScorerPy = cached_scorer_call + # used to detect the function hasn't been wrapped afterwards + py_func._RF_OriginalScorer = py_func if impl == "cpp": cpp_mod = importlib.import_module(module + "_cpp") @@ -85,6 +87,8 @@ def fallback_import( if cached_scorer_call: cpp_func._RF_ScorerPy = cached_scorer_call + # used to detect the function hasn't been wrapped afterwards + cpp_func._RF_OriginalScorer = cpp_func return cpp_func diff --git a/src/rapidfuzz/cpp_common.pxd b/src/rapidfuzz/cpp_common.pxd index 5729a99..73b68de 100644 --- a/src/rapidfuzz/cpp_common.pxd +++ b/src/rapidfuzz/cpp_common.pxd @@ -409,8 +409,3 @@ cdef inline RF_Scorer CreateScorerContext(RF_KwargsInit kwargs_init, RF_GetScore cdef inline dict CreateScorerContextPy(get_scorer_flags): return {"get_scorer_flags": get_scorer_flags} - -cdef inline bool AddScorerContext(func, py_context, RF_Scorer* c_context) except False: - func._RF_Scorer = PyCapsule_New(c_context, NULL, NULL) - func._RF_ScorerPy = py_context - return True diff --git a/src/rapidfuzz/fuzz_cpp.pyx b/src/rapidfuzz/fuzz_cpp.pyx index a20b200..7c00a34 100644 --- a/src/rapidfuzz/fuzz_cpp.pyx +++ b/src/rapidfuzz/fuzz_cpp.pyx @@ -20,8 +20,9 @@ from array import array from rapidfuzz.utils import default_process +from cpython.pycapsule cimport PyCapsule_New + from cpp_common cimport ( - AddScorerContext, CreateScorerContext, CreateScorerContextPy, NoKwargsInit, @@ -217,31 +218,31 @@ def _GetScorerFlagsSimilarity(**kwargs): cdef dict FuzzContextPy = CreateScorerContextPy(_GetScorerFlagsSimilarity) cdef RF_Scorer RatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, RatioInit) -AddScorerContext(ratio, FuzzContextPy, &RatioContext) +ratio._RF_Scorer = PyCapsule_New(&RatioContext, NULL, NULL) cdef RF_Scorer PartialRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialRatioInit) -AddScorerContext(partial_ratio, FuzzContextPy, &PartialRatioContext) +partial_ratio._RF_Scorer = PyCapsule_New(&PartialRatioContext, NULL, NULL) cdef RF_Scorer TokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, TokenSortRatioInit) -AddScorerContext(token_sort_ratio, FuzzContextPy, &TokenSortRatioContext) +token_sort_ratio._RF_Scorer = PyCapsule_New(&TokenSortRatioContext, NULL, NULL) cdef RF_Scorer TokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenSetRatioInit) -AddScorerContext(token_set_ratio, FuzzContextPy, &TokenSetRatioContext) +token_set_ratio._RF_Scorer = PyCapsule_New(&TokenSetRatioContext, NULL, NULL) cdef RF_Scorer TokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, TokenRatioInit) -AddScorerContext(token_ratio, FuzzContextPy, &TokenRatioContext) +token_ratio._RF_Scorer = PyCapsule_New(&TokenRatioContext, NULL, NULL) cdef RF_Scorer PartialTokenSortRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSortRatioInit) -AddScorerContext(partial_token_sort_ratio, FuzzContextPy, &PartialTokenSortRatioContext) +partial_token_sort_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSortRatioContext, NULL, NULL) cdef RF_Scorer PartialTokenSetRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenSetRatioInit) -AddScorerContext(partial_token_set_ratio, FuzzContextPy, &PartialTokenSetRatioContext) +partial_token_set_ratio._RF_Scorer = PyCapsule_New(&PartialTokenSetRatioContext, NULL, NULL) cdef RF_Scorer PartialTokenRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, PartialTokenRatioInit) -AddScorerContext(partial_token_ratio, FuzzContextPy, &PartialTokenRatioContext) +partial_token_ratio._RF_Scorer = PyCapsule_New(&PartialTokenRatioContext, NULL, NULL) cdef RF_Scorer WRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzz, WRatioInit) -AddScorerContext(WRatio, FuzzContextPy, &WRatioContext) +WRatio._RF_Scorer = PyCapsule_New(&WRatioContext, NULL, NULL) cdef RF_Scorer QRatioContext = CreateScorerContext(NoKwargsInit, GetScorerFlagsFuzzRatio, QRatioInit) -AddScorerContext(QRatio, FuzzContextPy, &QRatioContext) +QRatio._RF_Scorer = PyCapsule_New(&QRatioContext, NULL, NULL) diff --git a/src/rapidfuzz/process_cpp_impl.pyx b/src/rapidfuzz/process_cpp_impl.pyx index 77eb13e..eaee5cf 100644 --- a/src/rapidfuzz/process_cpp_impl.pyx +++ b/src/rapidfuzz/process_cpp_impl.pyx @@ -1482,6 +1482,7 @@ cdef cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers, def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None, score_hint=None, dtype=None, workers=1, **kwargs): cdef RF_Scorer* scorer_context = NULL cdef RF_ScorerFlags scorer_flags + cdef bool is_orig_scorer if processor is True: # todo: deprecate this @@ -1493,20 +1494,21 @@ def cdist(queries, choices, *, scorer=ratio, processor=None, score_cutoff=None, if PyCapsule_IsValid(scorer_capsule, NULL): scorer_context = PyCapsule_GetPointer(scorer_capsule, NULL) - if scorer_context: - if scorer_context.version == SCORER_STRUCT_VERSION: - kwargs_context = RF_KwargsWrapper() - scorer_context.kwargs_init(&kwargs_context.kwargs, kwargs) - scorer_context.get_scorer_flags(&kwargs_context.kwargs, &scorer_flags) + is_orig_scorer = getattr(scorer, '_RF_OriginalScorer', None) is scorer - # scorer(queries[i], choices[j]) == scorer(queries[j], choices[i]) - if scorer_flags.flags & RF_SCORER_FLAG_SYMMETRIC and queries is choices: - return cdist_single_list( - queries, scorer_context, &scorer_flags, processor, - score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs) - else: - return cdist_two_lists( - queries, choices, scorer_context, &scorer_flags, processor, - score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs) + if is_orig_scorer and scorer_context and scorer_context.version == SCORER_STRUCT_VERSION: + kwargs_context = RF_KwargsWrapper() + scorer_context.kwargs_init(&kwargs_context.kwargs, kwargs) + scorer_context.get_scorer_flags(&kwargs_context.kwargs, &scorer_flags) + + # scorer(queries[i], choices[j]) == scorer(queries[j], choices[i]) + if scorer_flags.flags & RF_SCORER_FLAG_SYMMETRIC and queries is choices: + return cdist_single_list( + queries, scorer_context, &scorer_flags, processor, + score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs) + else: + return cdist_two_lists( + queries, choices, scorer_context, &scorer_flags, processor, + score_cutoff, score_hint, dtype, workers, &kwargs_context.kwargs) return cdist_py(queries, choices, scorer, processor, score_cutoff, dtype, workers, kwargs) diff --git a/tests/test_process.py b/tests/test_process.py index fcc1fa5..ea969b4 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -2,6 +2,13 @@ import pytest from rapidfuzz import fuzz, process_cpp, process_py +def wrapped(func): + from functools import wraps + @wraps(func) + def decorator(*args, **kwargs): + return 100 + + return decorator class process: @staticmethod @@ -27,6 +34,13 @@ class process: assert res1 == res2 return res1 + @staticmethod + def cdist(*args, **kwargs): + res1 = process_cpp.cdist(*args, **kwargs) + res2 = process_py.cdist(*args, **kwargs) + assert res1 == res2 + return res1 + baseball_strings = [ "new york mets vs chicago cubs", @@ -351,7 +365,14 @@ def test_extractOne_use_first_match(scorer): @pytest.mark.parametrize("scorer", [fuzz.ratio, fuzz.WRatio, custom_scorer]) def test_cdist_empty_seq(scorer): pytest.importorskip("numpy") - assert process_cpp.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2) - assert process_cpp.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0) - assert process_py.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2) - assert process_py.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0) + assert process.cdist([], ["a", "b"], scorer=scorer).shape == (0, 2) + assert process.cdist(["a", "b"], [], scorer=scorer).shape == (2, 0) + + +@pytest.mark.parametrize("scorer", [fuzz.ratio]) +def test_wrapped_function(scorer): + pytest.importorskip("numpy") + scorer = wrapped(scorer) + assert process.cdist(["test"], [float("nan")], scorer=scorer)[0, 0] == 100 + assert process.cdist(["test"], [None], scorer=scorer)[0, 0] == 100 + assert process.cdist(["test"], ["tes"], scorer=scorer)[0, 0] == 100