Fix incorrect normalization for legacy scorers

This commit is contained in:
Max Bachmann 2022-01-02 15:44:28 +01:00
parent 73f99a4ee9
commit 04969928a7
1 changed files with 54 additions and 4 deletions

View File

@ -1,6 +1,56 @@
#pragma once #pragma once
#include "cpp_common.hpp" #include "cpp_common.hpp"
template<typename CachedScorer>
static inline bool legacy_scorer_func_wrapper_f64(const RF_ScorerFunc* self, const RF_String* str, double score_cutoff, double* result)
{
CachedScorer& scorer = *(CachedScorer*)self->context;
try {
*result = visit(*str, [&](auto s){
return scorer.ratio(s, score_cutoff);
}) * 100;
} catch(...) {
PyGILState_STATE gilstate_save = PyGILState_Ensure();
CppExn2PyErr();
PyGILState_Release(gilstate_save);
return false;
}
return true;
}
template<template <typename> class CachedScorer, typename Sentence, typename ...Args>
static inline RF_ScorerFunc legacy_get_ScorerContext_f64(Sentence str, Args... args)
{
RF_ScorerFunc context;
context.context = (void*) new CachedScorer<Sentence>(str, args...);
context.call.f64 = legacy_scorer_func_wrapper_f64<CachedScorer<Sentence>>;
context.dtor = scorer_deinit<CachedScorer<Sentence>>;
return context;
}
template<template <typename> class CachedScorer, typename ...Args>
static inline bool legacy_scorer_init_f64(RF_ScorerFunc* self, size_t str_count, const RF_String* strings, Args... args)
{
try {
/* todo support different string counts, which is required e.g. for SIMD */
if (str_count != 1)
{
throw std::logic_error("Only str_count == 1 supported");
}
*self = visit(*strings, [&](auto s){
return legacy_get_ScorerContext_f64<CachedScorer>(s, args...);
});
} catch(...) {
PyGILState_STATE gilstate_save = PyGILState_Ensure();
CppExn2PyErr();
PyGILState_Release(gilstate_save);
return false;
}
return true;
}
static inline size_t levenshtein_func(const RF_String& s1, const RF_String& s2, static inline size_t levenshtein_func(const RF_String& s1, const RF_String& s2,
size_t insertion, size_t deletion, size_t substitution, size_t max) size_t insertion, size_t deletion, size_t substitution, size_t max)
{ {
@ -26,7 +76,7 @@ static inline double normalized_levenshtein_func(const RF_String& s1, const RF_S
} }
static inline bool NormalizedLevenshteinInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str) static inline bool NormalizedLevenshteinInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str)
{ {
return scorer_init_f64<string_metric::CachedNormalizedLevenshtein>( return legacy_scorer_init_f64<string_metric::CachedNormalizedLevenshtein>(
self, str_count, str, *(rapidfuzz::LevenshteinWeightTable*)(kwargs->context) self, str_count, str, *(rapidfuzz::LevenshteinWeightTable*)(kwargs->context)
); );
} }
@ -50,7 +100,7 @@ static inline double normalized_hamming_func(const RF_String& s1, const RF_Strin
} }
static inline bool NormalizedHammingInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str) static inline bool NormalizedHammingInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str)
{ {
return scorer_init_f64<string_metric::CachedNormalizedHamming>(self, str_count, str); return legacy_scorer_init_f64<string_metric::CachedNormalizedHamming>(self, str_count, str);
} }
static inline double jaro_similarity_func(const RF_String& s1, const RF_String& s2, double score_cutoff) static inline double jaro_similarity_func(const RF_String& s1, const RF_String& s2, double score_cutoff)
@ -61,7 +111,7 @@ static inline double jaro_similarity_func(const RF_String& s1, const RF_String&
} }
static inline bool JaroSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str) static inline bool JaroSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str)
{ {
return scorer_init_f64<string_metric::CachedJaroSimilarity>(self, str_count, str); return legacy_scorer_init_f64<string_metric::CachedJaroSimilarity>(self, str_count, str);
} }
static inline double jaro_winkler_similarity_func(const RF_String& s1, const RF_String& s2, static inline double jaro_winkler_similarity_func(const RF_String& s1, const RF_String& s2,
@ -73,7 +123,7 @@ static inline double jaro_winkler_similarity_func(const RF_String& s1, const RF_
} }
static inline bool JaroWinklerSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str) static inline bool JaroWinklerSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str)
{ {
return scorer_init_f64<string_metric::CachedJaroWinklerSimilarity>(self, str_count, str, *(double*)(kwargs->context)); return legacy_scorer_init_f64<string_metric::CachedJaroWinklerSimilarity>(self, str_count, str, *(double*)(kwargs->context));
} }
static inline std::vector<rapidfuzz::LevenshteinEditOp> levenshtein_editops_func( static inline std::vector<rapidfuzz::LevenshteinEditOp> levenshtein_editops_func(