mirror of https://github.com/explosion/spaCy.git
`StringStore`-related optimizations (#10938)
* `strings`: More robust type checking of keys/IDs, coerce `int`-like types to `hash_t` * Preserve existing public API behaviour * Fix return type * Replace `bool` with `bint`, rename to `_try_coerce_to_hash`, replace `id` with `hash` * Avoid unnecessary re-encoding and re-calculation of strings and hashes respectively * Rename variables named `hash` * Add comment on early return
This commit is contained in:
parent
7c1bf2fa1f
commit
59c763eec1
|
@ -26,4 +26,4 @@ cdef class StringStore:
|
||||||
cdef public PreshMap _map
|
cdef public PreshMap _map
|
||||||
|
|
||||||
cdef const Utf8Str* intern_unicode(self, str py_string)
|
cdef const Utf8Str* intern_unicode(self, str py_string)
|
||||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length)
|
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash)
|
||||||
|
|
|
@ -14,6 +14,13 @@ from .symbols import NAMES as SYMBOLS_BY_INT
|
||||||
from .errors import Errors
|
from .errors import Errors
|
||||||
from . import util
|
from . import util
|
||||||
|
|
||||||
|
# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)`
|
||||||
|
cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
|
||||||
|
try:
|
||||||
|
out_hash[0] = key
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
def get_string_id(key):
|
def get_string_id(key):
|
||||||
"""Get a string ID, handling the reserved symbols correctly. If the key is
|
"""Get a string ID, handling the reserved symbols correctly. If the key is
|
||||||
|
@ -22,15 +29,27 @@ def get_string_id(key):
|
||||||
This function optimises for convenience over performance, so shouldn't be
|
This function optimises for convenience over performance, so shouldn't be
|
||||||
used in tight loops.
|
used in tight loops.
|
||||||
"""
|
"""
|
||||||
if not isinstance(key, str):
|
cdef hash_t str_hash
|
||||||
return key
|
if isinstance(key, str):
|
||||||
elif key in SYMBOLS_BY_STR:
|
if len(key) == 0:
|
||||||
return SYMBOLS_BY_STR[key]
|
|
||||||
elif not key:
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
symbol = SYMBOLS_BY_STR.get(key, None)
|
||||||
|
if symbol is not None:
|
||||||
|
return symbol
|
||||||
else:
|
else:
|
||||||
chars = key.encode("utf8")
|
chars = key.encode("utf8")
|
||||||
return hash_utf8(chars, len(chars))
|
return hash_utf8(chars, len(chars))
|
||||||
|
elif _try_coerce_to_hash(key, &str_hash):
|
||||||
|
# Coerce the integral key to the expected primitive hash type.
|
||||||
|
# This ensures that custom/overloaded "primitive" data types
|
||||||
|
# such as those implemented by numpy are not inadvertently used
|
||||||
|
# downstream (as these are internally implemented as custom PyObjects
|
||||||
|
# whose comparison operators can incur a significant overhead).
|
||||||
|
return str_hash
|
||||||
|
else:
|
||||||
|
# TODO: Raise an error instead
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
cpdef hash_t hash_string(str string) except 0:
|
cpdef hash_t hash_string(str string) except 0:
|
||||||
|
@ -110,24 +129,32 @@ cdef class StringStore:
|
||||||
string_or_id (bytes, str or uint64): The value to encode.
|
string_or_id (bytes, str or uint64): The value to encode.
|
||||||
Returns (str / uint64): The value to be retrieved.
|
Returns (str / uint64): The value to be retrieved.
|
||||||
"""
|
"""
|
||||||
if isinstance(string_or_id, str) and len(string_or_id) == 0:
|
cdef hash_t str_hash
|
||||||
return 0
|
cdef Utf8Str* utf8str = NULL
|
||||||
elif string_or_id == 0:
|
|
||||||
return ""
|
|
||||||
elif string_or_id in SYMBOLS_BY_STR:
|
|
||||||
return SYMBOLS_BY_STR[string_or_id]
|
|
||||||
cdef hash_t key
|
|
||||||
if isinstance(string_or_id, str):
|
if isinstance(string_or_id, str):
|
||||||
key = hash_string(string_or_id)
|
if len(string_or_id) == 0:
|
||||||
return key
|
return 0
|
||||||
elif isinstance(string_or_id, bytes):
|
|
||||||
key = hash_utf8(string_or_id, len(string_or_id))
|
# Return early if the string is found in the symbols LUT.
|
||||||
return key
|
symbol = SYMBOLS_BY_STR.get(string_or_id, None)
|
||||||
elif string_or_id < len(SYMBOLS_BY_INT):
|
if symbol is not None:
|
||||||
return SYMBOLS_BY_INT[string_or_id]
|
return symbol
|
||||||
else:
|
else:
|
||||||
key = string_or_id
|
return hash_string(string_or_id)
|
||||||
utf8str = <Utf8Str*>self._map.get(key)
|
elif isinstance(string_or_id, bytes):
|
||||||
|
return hash_utf8(string_or_id, len(string_or_id))
|
||||||
|
elif _try_coerce_to_hash(string_or_id, &str_hash):
|
||||||
|
if str_hash == 0:
|
||||||
|
return ""
|
||||||
|
elif str_hash < len(SYMBOLS_BY_INT):
|
||||||
|
return SYMBOLS_BY_INT[str_hash]
|
||||||
|
else:
|
||||||
|
utf8str = <Utf8Str*>self._map.get(str_hash)
|
||||||
|
else:
|
||||||
|
# TODO: Raise an error instead
|
||||||
|
utf8str = <Utf8Str*>self._map.get(string_or_id)
|
||||||
|
|
||||||
if utf8str is NULL:
|
if utf8str is NULL:
|
||||||
raise KeyError(Errors.E018.format(hash_value=string_or_id))
|
raise KeyError(Errors.E018.format(hash_value=string_or_id))
|
||||||
else:
|
else:
|
||||||
|
@ -153,19 +180,22 @@ cdef class StringStore:
|
||||||
string (str): The string to add.
|
string (str): The string to add.
|
||||||
RETURNS (uint64): The string's hash value.
|
RETURNS (uint64): The string's hash value.
|
||||||
"""
|
"""
|
||||||
|
cdef hash_t str_hash
|
||||||
if isinstance(string, str):
|
if isinstance(string, str):
|
||||||
if string in SYMBOLS_BY_STR:
|
if string in SYMBOLS_BY_STR:
|
||||||
return SYMBOLS_BY_STR[string]
|
return SYMBOLS_BY_STR[string]
|
||||||
key = hash_string(string)
|
|
||||||
self.intern_unicode(string)
|
string = string.encode("utf8")
|
||||||
|
str_hash = hash_utf8(string, len(string))
|
||||||
|
self._intern_utf8(string, len(string), &str_hash)
|
||||||
elif isinstance(string, bytes):
|
elif isinstance(string, bytes):
|
||||||
if string in SYMBOLS_BY_STR:
|
if string in SYMBOLS_BY_STR:
|
||||||
return SYMBOLS_BY_STR[string]
|
return SYMBOLS_BY_STR[string]
|
||||||
key = hash_utf8(string, len(string))
|
str_hash = hash_utf8(string, len(string))
|
||||||
self._intern_utf8(string, len(string))
|
self._intern_utf8(string, len(string), &str_hash)
|
||||||
else:
|
else:
|
||||||
raise TypeError(Errors.E017.format(value_type=type(string)))
|
raise TypeError(Errors.E017.format(value_type=type(string)))
|
||||||
return key
|
return str_hash
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""The number of strings in the store.
|
"""The number of strings in the store.
|
||||||
|
@ -174,30 +204,29 @@ cdef class StringStore:
|
||||||
"""
|
"""
|
||||||
return self.keys.size()
|
return self.keys.size()
|
||||||
|
|
||||||
def __contains__(self, string not None):
|
def __contains__(self, string_or_id not None):
|
||||||
"""Check whether a string is in the store.
|
"""Check whether a string or ID is in the store.
|
||||||
|
|
||||||
string (str): The string to check.
|
string_or_id (str or int): The string or ID to check.
|
||||||
RETURNS (bool): Whether the store contains the string.
|
RETURNS (bool): Whether the store contains the string.
|
||||||
"""
|
"""
|
||||||
cdef hash_t key
|
cdef hash_t str_hash
|
||||||
if isinstance(string, int) or isinstance(string, long):
|
if isinstance(string_or_id, str):
|
||||||
if string == 0:
|
if len(string_or_id) == 0:
|
||||||
return True
|
return True
|
||||||
key = string
|
elif string_or_id in SYMBOLS_BY_STR:
|
||||||
elif len(string) == 0:
|
|
||||||
return True
|
return True
|
||||||
elif string in SYMBOLS_BY_STR:
|
str_hash = hash_string(string_or_id)
|
||||||
return True
|
elif _try_coerce_to_hash(string_or_id, &str_hash):
|
||||||
elif isinstance(string, str):
|
pass
|
||||||
key = hash_string(string)
|
|
||||||
else:
|
else:
|
||||||
string = string.encode("utf8")
|
# TODO: Raise an error instead
|
||||||
key = hash_utf8(string, len(string))
|
return self._map.get(string_or_id) is not NULL
|
||||||
if key < len(SYMBOLS_BY_INT):
|
|
||||||
|
if str_hash < len(SYMBOLS_BY_INT):
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return self._map.get(key) is not NULL
|
return self._map.get(str_hash) is not NULL
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""Iterate over the strings in the store, in order.
|
"""Iterate over the strings in the store, in order.
|
||||||
|
@ -272,13 +301,13 @@ cdef class StringStore:
|
||||||
cdef const Utf8Str* intern_unicode(self, str py_string):
|
cdef const Utf8Str* intern_unicode(self, str py_string):
|
||||||
# 0 means missing, but we don't bother offsetting the index.
|
# 0 means missing, but we don't bother offsetting the index.
|
||||||
cdef bytes byte_string = py_string.encode("utf8")
|
cdef bytes byte_string = py_string.encode("utf8")
|
||||||
return self._intern_utf8(byte_string, len(byte_string))
|
return self._intern_utf8(byte_string, len(byte_string), NULL)
|
||||||
|
|
||||||
@cython.final
|
@cython.final
|
||||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length):
|
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash):
|
||||||
# TODO: This function's API/behaviour is an unholy mess...
|
# TODO: This function's API/behaviour is an unholy mess...
|
||||||
# 0 means missing, but we don't bother offsetting the index.
|
# 0 means missing, but we don't bother offsetting the index.
|
||||||
cdef hash_t key = hash_utf8(utf8_string, length)
|
cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else hash_utf8(utf8_string, length)
|
||||||
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
||||||
if value is not NULL:
|
if value is not NULL:
|
||||||
return value
|
return value
|
||||||
|
|
Loading…
Reference in New Issue