mirror of https://github.com/explosion/spaCy.git
Work around Issue #285: Allow the StringStore to be 'frozen', in which case new strings will be pushed into an OOV map. We can then flush this OOV map, freeing all of the OOV strings.
This commit is contained in:
parent
d3a617aa99
commit
d8134817ff
|
@ -1,3 +1,4 @@
|
|||
# cython: infer_types=True
|
||||
from __future__ import unicode_literals, absolute_import
|
||||
|
||||
cimport cython
|
||||
|
@ -71,12 +72,14 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except
|
|||
|
||||
cdef class StringStore:
|
||||
'''Map strings to and from integer IDs.'''
|
||||
def __init__(self, strings=None):
|
||||
def __init__(self, strings=None, freeze=False):
|
||||
self.mem = Pool()
|
||||
self._map = PreshMap()
|
||||
self._oov = PreshMap()
|
||||
self._resize_at = 10000
|
||||
self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
|
||||
self.size = 1
|
||||
self.is_frozen = False
|
||||
if strings is not None:
|
||||
for string in strings:
|
||||
_ = self[string]
|
||||
|
@ -89,33 +92,37 @@ cdef class StringStore:
|
|||
return self.size-1
|
||||
|
||||
def __getitem__(self, object string_or_id):
|
||||
if isinstance(string_or_id, basestring) and len(string_or_id) == 0:
|
||||
return 0
|
||||
elif string_or_id == 0:
|
||||
return u''
|
||||
|
||||
cdef bytes byte_string
|
||||
cdef const Utf8Str* utf8str
|
||||
cdef unsigned int int_id
|
||||
|
||||
cdef uint64_t int_id
|
||||
if isinstance(string_or_id, (int, long)):
|
||||
try:
|
||||
int_id = string_or_id
|
||||
except OverflowError:
|
||||
raise IndexError(string_or_id)
|
||||
if int_id == 0:
|
||||
return u''
|
||||
elif int_id >= <uint64_t>self.size:
|
||||
raise IndexError(string_or_id)
|
||||
utf8str = &self.c[int_id]
|
||||
return _decode(utf8str)
|
||||
elif isinstance(string_or_id, bytes):
|
||||
byte_string = <bytes>string_or_id
|
||||
if len(byte_string) == 0:
|
||||
return 0
|
||||
int_id = string_or_id
|
||||
if int_id < <uint64_t>self.size:
|
||||
return _decode(&self.c[int_id])
|
||||
else:
|
||||
utf8str = <Utf8Str*>self._oov.get(int_id)
|
||||
if utf8str is not NULL:
|
||||
return _decode(utf8str)
|
||||
else:
|
||||
raise IndexError(string_or_id)
|
||||
elif isinstance(string_or_id, basestring):
|
||||
if isinstance(string_or_id, bytes):
|
||||
byte_string = <bytes>string_or_id
|
||||
else:
|
||||
byte_string = (<unicode>string_or_id).encode('utf8')
|
||||
utf8str = self._intern_utf8(byte_string, len(byte_string))
|
||||
return utf8str - self.c
|
||||
elif isinstance(string_or_id, unicode):
|
||||
if len(<unicode>string_or_id) == 0:
|
||||
return 0
|
||||
byte_string = (<unicode>string_or_id).encode('utf8')
|
||||
utf8str = self._intern_utf8(byte_string, len(byte_string))
|
||||
return utf8str - self.c
|
||||
if utf8str is NULL:
|
||||
# TODO: We could get unlucky here, and hash into a value that
|
||||
# collides with the 'real' strings. All we have to do is offset
|
||||
# the hash, I think?
|
||||
return _hash_utf8(byte_string, len(byte_string))
|
||||
else:
|
||||
return utf8str - self.c
|
||||
else:
|
||||
raise TypeError(type(string_or_id))
|
||||
|
||||
|
@ -129,6 +136,7 @@ cdef class StringStore:
|
|||
cdef int i
|
||||
for i in range(self.size):
|
||||
yield _decode(&self.c[i]) if i > 0 else u''
|
||||
# TODO: Iterate OOV here?
|
||||
|
||||
def __reduce__(self):
|
||||
strings = [""]
|
||||
|
@ -138,18 +146,36 @@ cdef class StringStore:
|
|||
strings.append(py_string)
|
||||
return (StringStore, (strings,), None, None, None)
|
||||
|
||||
cdef const Utf8Str* intern(self, unicode py_string) except NULL:
|
||||
def set_frozen(self, bint is_frozen):
|
||||
# TODO
|
||||
self.is_frozen = is_frozen
|
||||
|
||||
def flush_oov(self):
|
||||
self._oov = PreshMap()
|
||||
|
||||
cdef const Utf8Str* intern_unicode(self, unicode py_string):
|
||||
# 0 means missing, but we don't bother offsetting the index.
|
||||
cdef bytes byte_string = py_string.encode('utf8')
|
||||
return self._intern_utf8(byte_string, len(byte_string))
|
||||
|
||||
@cython.final
|
||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length) except NULL:
|
||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length):
|
||||
# TODO: This function's API/behaviour is an unholy mess...
|
||||
# 0 means missing, but we don't bother offsetting the index.
|
||||
cdef hash_t key = _hash_utf8(utf8_string, length)
|
||||
value = <Utf8Str*>self._map.get(key)
|
||||
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
||||
if value is not NULL:
|
||||
return value
|
||||
value = <Utf8Str*>self._oov.get(key)
|
||||
if value is not NULL:
|
||||
return value
|
||||
if self.is_frozen:
|
||||
# Important: Make the OOV store own the memory. That way it's trivial
|
||||
# to flush them all.
|
||||
value = <Utf8Str*>self._oov.mem.alloc(1, sizeof(Utf8Str))
|
||||
value[0] = _allocate(self._oov.mem, <unsigned char*>utf8_string, length)
|
||||
self._oov.set(key, value)
|
||||
return NULL
|
||||
|
||||
if self.size == self._resize_at:
|
||||
self._realloc()
|
||||
|
@ -162,6 +188,7 @@ cdef class StringStore:
|
|||
string_data = json.dumps(list(self))
|
||||
if not isinstance(string_data, unicode):
|
||||
string_data = string_data.decode('utf8')
|
||||
# TODO: OOV?
|
||||
file_.write(string_data)
|
||||
|
||||
def load(self, file_):
|
||||
|
@ -173,7 +200,7 @@ cdef class StringStore:
|
|||
# explicit None/len check instead of simple truth testing
|
||||
# (bug in Cython <= 0.23.4)
|
||||
if string is not None and len(string):
|
||||
self.intern(string)
|
||||
self.intern_unicode(string)
|
||||
|
||||
def _realloc(self):
|
||||
# We want to map straight to pointers, but they'll be invalidated if
|
||||
|
|
Loading…
Reference in New Issue