From 59c763eec171e9285b39e793baa2cfbf2ccd48d7 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Mon, 4 Jul 2022 15:04:03 +0200 Subject: [PATCH] `StringStore`-related optimizations (#10938) * `strings`: More roubust type checking of keys/IDs, coerce `int`-like types to `hash_t` * Preserve existing public API behaviour * Fix return type * Replace `bool` with `bint`, rename to `_try_coerce_to_hash`, replace `id` with `hash` * Avoid unnecessary re-encoding and re-calculation of strings and hashs respectively * Rename variables named `hash` Add comment on early return --- spacy/strings.pxd | 2 +- spacy/strings.pyx | 135 ++++++++++++++++++++++++++++------------------ 2 files changed, 83 insertions(+), 54 deletions(-) diff --git a/spacy/strings.pxd b/spacy/strings.pxd index 370180135..5f03a9a28 100644 --- a/spacy/strings.pxd +++ b/spacy/strings.pxd @@ -26,4 +26,4 @@ cdef class StringStore: cdef public PreshMap _map cdef const Utf8Str* intern_unicode(self, str py_string) - cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length) + cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash) diff --git a/spacy/strings.pyx b/spacy/strings.pyx index 39fc441e9..c5f218342 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -14,6 +14,13 @@ from .symbols import NAMES as SYMBOLS_BY_INT from .errors import Errors from . import util +# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)` +cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash): + try: + out_hash[0] = key + return True + except: + return False def get_string_id(key): """Get a string ID, handling the reserved symbols correctly. If the key is @@ -22,15 +29,27 @@ def get_string_id(key): This function optimises for convenience over performance, so shouldn't be used in tight loops. """ - if not isinstance(key, str): - return key - elif key in SYMBOLS_BY_STR: - return SYMBOLS_BY_STR[key] - elif not key: - return 0 + cdef hash_t str_hash + if isinstance(key, str): + if len(key) == 0: + return 0 + + symbol = SYMBOLS_BY_STR.get(key, None) + if symbol is not None: + return symbol + else: + chars = key.encode("utf8") + return hash_utf8(chars, len(chars)) + elif _try_coerce_to_hash(key, &str_hash): + # Coerce the integral key to the expected primitive hash type. + # This ensures that custom/overloaded "primitive" data types + # such as those implemented by numpy are not inadvertently used + # downsteam (as these are internally implemented as custom PyObjects + # whose comparison operators can incur a significant overhead). + return str_hash else: - chars = key.encode("utf8") - return hash_utf8(chars, len(chars)) + # TODO: Raise an error instead + return key cpdef hash_t hash_string(str string) except 0: @@ -110,28 +129,36 @@ cdef class StringStore: string_or_id (bytes, str or uint64): The value to encode. Returns (str / uint64): The value to be retrieved. """ - if isinstance(string_or_id, str) and len(string_or_id) == 0: - return 0 - elif string_or_id == 0: - return "" - elif string_or_id in SYMBOLS_BY_STR: - return SYMBOLS_BY_STR[string_or_id] - cdef hash_t key + cdef hash_t str_hash + cdef Utf8Str* utf8str = NULL + if isinstance(string_or_id, str): - key = hash_string(string_or_id) - return key - elif isinstance(string_or_id, bytes): - key = hash_utf8(string_or_id, len(string_or_id)) - return key - elif string_or_id < len(SYMBOLS_BY_INT): - return SYMBOLS_BY_INT[string_or_id] - else: - key = string_or_id - utf8str = self._map.get(key) - if utf8str is NULL: - raise KeyError(Errors.E018.format(hash_value=string_or_id)) + if len(string_or_id) == 0: + return 0 + + # Return early if the string is found in the symbols LUT. + symbol = SYMBOLS_BY_STR.get(string_or_id, None) + if symbol is not None: + return symbol else: - return decode_Utf8Str(utf8str) + return hash_string(string_or_id) + elif isinstance(string_or_id, bytes): + return hash_utf8(string_or_id, len(string_or_id)) + elif _try_coerce_to_hash(string_or_id, &str_hash): + if str_hash == 0: + return "" + elif str_hash < len(SYMBOLS_BY_INT): + return SYMBOLS_BY_INT[str_hash] + else: + utf8str = self._map.get(str_hash) + else: + # TODO: Raise an error instead + utf8str = self._map.get(string_or_id) + + if utf8str is NULL: + raise KeyError(Errors.E018.format(hash_value=string_or_id)) + else: + return decode_Utf8Str(utf8str) def as_int(self, key): """If key is an int, return it; otherwise, get the int value.""" @@ -153,19 +180,22 @@ cdef class StringStore: string (str): The string to add. RETURNS (uint64): The string's hash value. """ + cdef hash_t str_hash if isinstance(string, str): if string in SYMBOLS_BY_STR: return SYMBOLS_BY_STR[string] - key = hash_string(string) - self.intern_unicode(string) + + string = string.encode("utf8") + str_hash = hash_utf8(string, len(string)) + self._intern_utf8(string, len(string), &str_hash) elif isinstance(string, bytes): if string in SYMBOLS_BY_STR: return SYMBOLS_BY_STR[string] - key = hash_utf8(string, len(string)) - self._intern_utf8(string, len(string)) + str_hash = hash_utf8(string, len(string)) + self._intern_utf8(string, len(string), &str_hash) else: raise TypeError(Errors.E017.format(value_type=type(string))) - return key + return str_hash def __len__(self): """The number of strings in the store. @@ -174,30 +204,29 @@ cdef class StringStore: """ return self.keys.size() - def __contains__(self, string not None): - """Check whether a string is in the store. + def __contains__(self, string_or_id not None): + """Check whether a string or ID is in the store. - string (str): The string to check. + string_or_id (str or int): The string to check. RETURNS (bool): Whether the store contains the string. """ - cdef hash_t key - if isinstance(string, int) or isinstance(string, long): - if string == 0: + cdef hash_t str_hash + if isinstance(string_or_id, str): + if len(string_or_id) == 0: return True - key = string - elif len(string) == 0: - return True - elif string in SYMBOLS_BY_STR: - return True - elif isinstance(string, str): - key = hash_string(string) + elif string_or_id in SYMBOLS_BY_STR: + return True + str_hash = hash_string(string_or_id) + elif _try_coerce_to_hash(string_or_id, &str_hash): + pass else: - string = string.encode("utf8") - key = hash_utf8(string, len(string)) - if key < len(SYMBOLS_BY_INT): + # TODO: Raise an error instead + return self._map.get(string_or_id) is not NULL + + if str_hash < len(SYMBOLS_BY_INT): return True else: - return self._map.get(key) is not NULL + return self._map.get(str_hash) is not NULL def __iter__(self): """Iterate over the strings in the store, in order. @@ -272,13 +301,13 @@ cdef class StringStore: cdef const Utf8Str* intern_unicode(self, str py_string): # 0 means missing, but we don't bother offsetting the index. cdef bytes byte_string = py_string.encode("utf8") - return self._intern_utf8(byte_string, len(byte_string)) + return self._intern_utf8(byte_string, len(byte_string), NULL) @cython.final - cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length): + cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash): # TODO: This function's API/behaviour is an unholy mess... # 0 means missing, but we don't bother offsetting the index. - cdef hash_t key = hash_utf8(utf8_string, length) + cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else hash_utf8(utf8_string, length) cdef Utf8Str* value = self._map.get(key) if value is not NULL: return value