diff --git a/spacy/utf8string.pxd b/spacy/utf8string.pxd
index 82ae50022..16488b899 100644
--- a/spacy/utf8string.pxd
+++ b/spacy/utf8string.pxd
@@ -13,7 +13,7 @@ cdef struct Utf8Str:
 
 cdef class StringStore:
     cdef Pool mem
-    cdef PreshMap table
+    cdef PreshMap _map
     cdef Utf8Str* strings
     cdef int size
     cdef int _resize_at
diff --git a/spacy/utf8string.pyx b/spacy/utf8string.pyx
index 18d4a4e5e..426b531f4 100644
--- a/spacy/utf8string.pyx
+++ b/spacy/utf8string.pyx
@@ -8,7 +8,7 @@ SEPARATOR = '\n|-SEP-|\n'
 cdef class StringStore:
     def __init__(self):
         self.mem = Pool()
-        self.table = PreshMap()
+        self._map = PreshMap()
         self._resize_at = 10000
         self.strings = self.mem.alloc(self._resize_at, sizeof(Utf8Str))
         self.size = 1
@@ -17,17 +17,21 @@ cdef class StringStore:
         def __get__(self):
             return self.size-1
 
-    def __getitem__(self, string_or_id):
+    def __getitem__(self, object string_or_id):
         cdef bytes byte_string
         cdef Utf8Str* utf8str
-        if type(string_or_id) == int or type(string_or_id) == long:
+        if isinstance(string_or_id, (int, long)):
             if string_or_id < 1 or string_or_id >= self.size:
                 raise IndexError(string_or_id)
             utf8str = &self.strings[string_or_id]
             return utf8str.chars[:utf8str.length]
-        elif type(string_or_id) == bytes:
+        elif isinstance(string_or_id, bytes):
             utf8str = self.intern(string_or_id, len(string_or_id))
             return utf8str.i
+        elif isinstance(string_or_id, unicode):
+            byte_string = string_or_id.encode('utf8')
+            utf8str = self.intern(byte_string, len(byte_string))
+            return utf8str.i
         else:
             raise TypeError(type(string_or_id))
 
@@ -36,7 +40,7 @@ cdef class StringStore:
         # slot 0 to simplify the code, because it doesn't matter.
         assert length != 0
         cdef hash_t key = hash64(chars, length * sizeof(char), 0)
-        cdef void* value = self.table.get(key)
+        cdef void* value = self._map.get(key)
         cdef size_t i
         if value == NULL:
             if self.size == self._resize_at:
@@ -48,7 +52,7 @@ cdef class StringStore:
             self.strings[i].chars = self.mem.alloc(length, sizeof(char))
             memcpy(self.strings[i].chars, chars, length)
             self.strings[i].length = length
-            self.table.set(key, self.size)
+            self._map.set(key, self.size)
             self.size += 1
         else:
             i = value
diff --git a/tests/test_intern.py b/tests/test_intern.py
index 63b4b3433..a7a801b05 100644
--- a/tests/test_intern.py
+++ b/tests/test_intern.py
@@ -19,8 +19,12 @@ def test_save_bytes(sstore):
 
 
 def test_save_unicode(sstore):
-    with pytest.raises(TypeError):
-        A_i = sstore['A']
+    Hello_i = sstore[u'Hello']
+    assert Hello_i == 1
+    assert sstore[u'Hello'] == 1
+    assert sstore[u'goodbye'] != Hello_i
+    assert sstore[u'hello'] != Hello_i
+    assert sstore[u'Hello'] == Hello_i
 
 
 def test_zero_id(sstore):