From 3b6f22ffa660c56ff38eb21531f36651e1c1c77a Mon Sep 17 00:00:00 2001
From: Max Bachmann
Date: Wed, 17 Aug 2022 22:05:07 +0200
Subject: [PATCH] fix hashing for custom classes

---
 CHANGELOG.md                       |  4 ++++
 src/rapidfuzz/cpp_common.pxd       | 11 ++++-------
 tests/distance/test_Levenshtein.py | 13 +++++++++++++
 3 files changed, 21 insertions(+), 7 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c51966e..ceea184 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 ## Changelog
 
+### [2.6.0] - 2022-08-
+#### Fixed
+- fix hashing for custom classes
+
 ### [2.5.0] - 2022-08-14
 #### Added
 - added support for KeyboardInterrupt in processor module
diff --git a/src/rapidfuzz/cpp_common.pxd b/src/rapidfuzz/cpp_common.pxd
index 2582003..752c9b0 100644
--- a/src/rapidfuzz/cpp_common.pxd
+++ b/src/rapidfuzz/cpp_common.pxd
@@ -238,11 +238,6 @@ cdef extern from "cpp_common.hpp":
     RF_String convert_string(object py_str)
     void validate_string(object py_str, const char* err) except +
 
-cdef inline uint64_t rf_hash(val) except *:
-    if val == -1:
-        return -1
-    return hash(val)
-
 cdef inline RF_String hash_array(arr) except *:
     # TODO on Cpython this does not require any copies
     cdef RF_String s_proc
@@ -283,7 +278,7 @@ cdef inline RF_String hash_array(arr) except *:
         else: # float/double are hashed
             s_proc.kind = RF_StringType.RF_UINT64
             for i in range(s_proc.length):
-                (s_proc.data)[i] = rf_hash(arr[i])
+                (s_proc.data)[i] = hash(arr[i])
     except Exception as e:
         free(s_proc.data)
         s_proc.data = NULL
@@ -309,8 +304,10 @@ cdef inline RF_String hash_sequence(seq) except *:
             # this is required so e.g. a list of char can be compared to a string
             if isinstance(elem, str) and len(elem) == 1:
                 (s_proc.data)[i] = elem
+            elif isinstance(elem, int) and elem == -1:
+                (s_proc.data)[i] = -1
             else:
-                (s_proc.data)[i] = rf_hash(elem)
+                (s_proc.data)[i] = hash(elem)
     except Exception as e:
         free(s_proc.data)
         s_proc.data = NULL
diff --git a/tests/distance/test_Levenshtein.py b/tests/distance/test_Levenshtein.py
index 8e8e03c..b70bfde 100644
--- a/tests/distance/test_Levenshtein.py
+++ b/tests/distance/test_Levenshtein.py
@@ -10,6 +10,18 @@ from rapidfuzz.distance import Opcodes, Opcode, Levenshtein_cpp, Levenshtein_py
 def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
     return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)
 
+class CustomHashable:
+    def __init__(self, string):
+        self._string = string
+
+    def __eq__(self, other):
+        raise NotImplementedError
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __hash__(self):
+        return hash(self._string)
 
 class Levenshtein:
     @staticmethod
@@ -61,6 +73,7 @@ def test_cross_type_matching():
     # todo add support in pure python
     assert Levenshtein_cpp.distance("aaaa", [ord("a"), ord("a"), "a", "a"]) == 0
     assert Levenshtein_cpp.distance([0, -1], [0, -2]) == 1
+    assert Levenshtein_cpp.distance([CustomHashable("aa"), CustomHashable("aa")], [CustomHashable("aa"), CustomHashable("bb")]) == 1
 
 
 def test_word_error_rate():