Fix incorrect normalization for legacy scorers
This commit is contained in:
parent
73f99a4ee9
commit
04969928a7
|
@ -1,6 +1,56 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "cpp_common.hpp"
|
#include "cpp_common.hpp"
|
||||||
|
|
||||||
|
|
||||||
|
template<typename CachedScorer>
|
||||||
|
static inline bool legacy_scorer_func_wrapper_f64(const RF_ScorerFunc* self, const RF_String* str, double score_cutoff, double* result)
|
||||||
|
{
|
||||||
|
CachedScorer& scorer = *(CachedScorer*)self->context;
|
||||||
|
try {
|
||||||
|
*result = visit(*str, [&](auto s){
|
||||||
|
return scorer.ratio(s, score_cutoff);
|
||||||
|
}) * 100;
|
||||||
|
} catch(...) {
|
||||||
|
PyGILState_STATE gilstate_save = PyGILState_Ensure();
|
||||||
|
CppExn2PyErr();
|
||||||
|
PyGILState_Release(gilstate_save);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<template <typename> class CachedScorer, typename Sentence, typename ...Args>
|
||||||
|
static inline RF_ScorerFunc legacy_get_ScorerContext_f64(Sentence str, Args... args)
|
||||||
|
{
|
||||||
|
RF_ScorerFunc context;
|
||||||
|
context.context = (void*) new CachedScorer<Sentence>(str, args...);
|
||||||
|
|
||||||
|
context.call.f64 = legacy_scorer_func_wrapper_f64<CachedScorer<Sentence>>;
|
||||||
|
context.dtor = scorer_deinit<CachedScorer<Sentence>>;
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<template <typename> class CachedScorer, typename ...Args>
|
||||||
|
static inline bool legacy_scorer_init_f64(RF_ScorerFunc* self, size_t str_count, const RF_String* strings, Args... args)
|
||||||
|
{
|
||||||
|
try {
|
||||||
|
/* todo support different string counts, which is required e.g. for SIMD */
|
||||||
|
if (str_count != 1)
|
||||||
|
{
|
||||||
|
throw std::logic_error("Only str_count == 1 supported");
|
||||||
|
}
|
||||||
|
*self = visit(*strings, [&](auto s){
|
||||||
|
return legacy_get_ScorerContext_f64<CachedScorer>(s, args...);
|
||||||
|
});
|
||||||
|
} catch(...) {
|
||||||
|
PyGILState_STATE gilstate_save = PyGILState_Ensure();
|
||||||
|
CppExn2PyErr();
|
||||||
|
PyGILState_Release(gilstate_save);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static inline size_t levenshtein_func(const RF_String& s1, const RF_String& s2,
|
static inline size_t levenshtein_func(const RF_String& s1, const RF_String& s2,
|
||||||
size_t insertion, size_t deletion, size_t substitution, size_t max)
|
size_t insertion, size_t deletion, size_t substitution, size_t max)
|
||||||
{
|
{
|
||||||
|
@ -26,7 +76,7 @@ static inline double normalized_levenshtein_func(const RF_String& s1, const RF_S
|
||||||
}
|
}
|
||||||
static inline bool NormalizedLevenshteinInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str)
|
static inline bool NormalizedLevenshteinInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str)
|
||||||
{
|
{
|
||||||
return scorer_init_f64<string_metric::CachedNormalizedLevenshtein>(
|
return legacy_scorer_init_f64<string_metric::CachedNormalizedLevenshtein>(
|
||||||
self, str_count, str, *(rapidfuzz::LevenshteinWeightTable*)(kwargs->context)
|
self, str_count, str, *(rapidfuzz::LevenshteinWeightTable*)(kwargs->context)
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -50,7 +100,7 @@ static inline double normalized_hamming_func(const RF_String& s1, const RF_Strin
|
||||||
}
|
}
|
||||||
static inline bool NormalizedHammingInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str)
|
static inline bool NormalizedHammingInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str)
|
||||||
{
|
{
|
||||||
return scorer_init_f64<string_metric::CachedNormalizedHamming>(self, str_count, str);
|
return legacy_scorer_init_f64<string_metric::CachedNormalizedHamming>(self, str_count, str);
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline double jaro_similarity_func(const RF_String& s1, const RF_String& s2, double score_cutoff)
|
static inline double jaro_similarity_func(const RF_String& s1, const RF_String& s2, double score_cutoff)
|
||||||
|
@ -61,7 +111,7 @@ static inline double jaro_similarity_func(const RF_String& s1, const RF_String&
|
||||||
}
|
}
|
||||||
static inline bool JaroSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str)
|
static inline bool JaroSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs*, size_t str_count, const RF_String* str)
|
||||||
{
|
{
|
||||||
return scorer_init_f64<string_metric::CachedJaroSimilarity>(self, str_count, str);
|
return legacy_scorer_init_f64<string_metric::CachedJaroSimilarity>(self, str_count, str);
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline double jaro_winkler_similarity_func(const RF_String& s1, const RF_String& s2,
|
static inline double jaro_winkler_similarity_func(const RF_String& s1, const RF_String& s2,
|
||||||
|
@ -73,7 +123,7 @@ static inline double jaro_winkler_similarity_func(const RF_String& s1, const RF_
|
||||||
}
|
}
|
||||||
static inline bool JaroWinklerSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str)
|
static inline bool JaroWinklerSimilarityInit(RF_ScorerFunc* self, const RF_Kwargs* kwargs, size_t str_count, const RF_String* str)
|
||||||
{
|
{
|
||||||
return scorer_init_f64<string_metric::CachedJaroWinklerSimilarity>(self, str_count, str, *(double*)(kwargs->context));
|
return legacy_scorer_init_f64<string_metric::CachedJaroWinklerSimilarity>(self, str_count, str, *(double*)(kwargs->context));
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline std::vector<rapidfuzz::LevenshteinEditOp> levenshtein_editops_func(
|
static inline std::vector<rapidfuzz::LevenshteinEditOp> levenshtein_editops_func(
|
||||||
|
|
Loading…
Reference in New Issue