mirror of https://github.com/explosion/spaCy.git
Implement StringStore serialization, and update tests
This commit is contained in:
parent
aae97f00e9
commit
d8bb5bb959
|
@ -7,9 +7,12 @@ from libc.string cimport memcpy
|
||||||
from libc.stdint cimport uint64_t, uint32_t
|
from libc.stdint cimport uint64_t, uint32_t
|
||||||
from murmurhash.mrmr cimport hash64, hash32
|
from murmurhash.mrmr cimport hash64, hash32
|
||||||
from preshed.maps cimport map_iter, key_t
|
from preshed.maps cimport map_iter, key_t
|
||||||
|
from libc.stdint cimport uint32_t
|
||||||
|
import ujson
|
||||||
|
import dill
|
||||||
|
|
||||||
from .typedefs cimport hash_t
|
from .typedefs cimport hash_t
|
||||||
from libc.stdint cimport uint32_t
|
from . import util
|
||||||
|
|
||||||
|
|
||||||
cpdef hash_t hash_string(unicode string) except 0:
|
cpdef hash_t hash_string(unicode string) except 0:
|
||||||
|
@ -92,14 +95,6 @@ cdef class StringStore:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return self.size -1
|
return self.size -1
|
||||||
|
|
||||||
def __reduce__(self):
|
|
||||||
# TODO: OOV words, for the is_frozen stuff?
|
|
||||||
if self.is_frozen:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Currently missing support for pickling StringStore when "
|
|
||||||
"is_frozen=True")
|
|
||||||
return (StringStore, (list(self),))
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""The number of strings in the store.
|
"""The number of strings in the store.
|
||||||
|
|
||||||
|
@ -186,7 +181,10 @@ cdef class StringStore:
|
||||||
path (unicode or Path): A path to a directory, which will be created if
|
path (unicode or Path): A path to a directory, which will be created if
|
||||||
it doesn't exist. Paths may be either strings or `Path`-like objects.
|
it doesn't exist. Paths may be either strings or `Path`-like objects.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
path = util.ensure_path(path)
|
||||||
|
strings = list(self)
|
||||||
|
with path.open('w') as file_:
|
||||||
|
ujson.dump(strings, file_)
|
||||||
|
|
||||||
def from_disk(self, path):
|
def from_disk(self, path):
|
||||||
"""Loads state from a directory. Modifies the object in place and
|
"""Loads state from a directory. Modifies the object in place and
|
||||||
|
@ -196,7 +194,11 @@ cdef class StringStore:
|
||||||
strings or `Path`-like objects.
|
strings or `Path`-like objects.
|
||||||
RETURNS (StringStore): The modified `StringStore` object.
|
RETURNS (StringStore): The modified `StringStore` object.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
path = util.ensure_path(path)
|
||||||
|
with path.open('r') as file_:
|
||||||
|
strings = ujson.load(file_)
|
||||||
|
self._reset_and_load(strings)
|
||||||
|
return self
|
||||||
|
|
||||||
def to_bytes(self, **exclude):
|
def to_bytes(self, **exclude):
|
||||||
"""Serialize the current state to a binary string.
|
"""Serialize the current state to a binary string.
|
||||||
|
@ -204,7 +206,7 @@ cdef class StringStore:
|
||||||
**exclude: Named attributes to prevent from being serialized.
|
**exclude: Named attributes to prevent from being serialized.
|
||||||
RETURNS (bytes): The serialized form of the `StringStore` object.
|
RETURNS (bytes): The serialized form of the `StringStore` object.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
return ujson.dumps(list(self))
|
||||||
|
|
||||||
def from_bytes(self, bytes_data, **exclude):
|
def from_bytes(self, bytes_data, **exclude):
|
||||||
"""Load state from a binary string.
|
"""Load state from a binary string.
|
||||||
|
@ -213,7 +215,9 @@ cdef class StringStore:
|
||||||
**exclude: Named attributes to prevent from being loaded.
|
**exclude: Named attributes to prevent from being loaded.
|
||||||
RETURNS (StringStore): The `StringStore` object.
|
RETURNS (StringStore): The `StringStore` object.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
strings = ujson.loads(bytes_data)
|
||||||
|
self._reset_and_load(strings)
|
||||||
|
return self
|
||||||
|
|
||||||
def set_frozen(self, bint is_frozen):
|
def set_frozen(self, bint is_frozen):
|
||||||
# TODO
|
# TODO
|
||||||
|
@ -222,6 +226,17 @@ cdef class StringStore:
|
||||||
def flush_oov(self):
|
def flush_oov(self):
|
||||||
self._oov = PreshMap()
|
self._oov = PreshMap()
|
||||||
|
|
||||||
|
def _reset_and_load(self, strings, freeze=False):
|
||||||
|
self.mem = Pool()
|
||||||
|
self._map = PreshMap()
|
||||||
|
self._oov = PreshMap()
|
||||||
|
self._resize_at = 10000
|
||||||
|
self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
|
||||||
|
self.size = 1
|
||||||
|
for string in strings:
|
||||||
|
_ = self[string]
|
||||||
|
self.is_frozen = freeze
|
||||||
|
|
||||||
cdef const Utf8Str* intern_unicode(self, unicode py_string):
|
cdef const Utf8Str* intern_unicode(self, unicode py_string):
|
||||||
# 0 means missing, but we don't bother offsetting the index.
|
# 0 means missing, but we don't bother offsetting the index.
|
||||||
cdef bytes byte_string = py_string.encode('utf8')
|
cdef bytes byte_string = py_string.encode('utf8')
|
||||||
|
|
|
@ -69,10 +69,8 @@ def test_stringstore_massive_strings(stringstore):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('text', ["qqqqq"])
|
@pytest.mark.parametrize('text', ["qqqqq"])
|
||||||
def test_stringstore_dump_load(stringstore, text_file, text):
|
def test_stringstore_to_bytes(stringstore, text):
|
||||||
store = stringstore[text]
|
store = stringstore[text]
|
||||||
stringstore.dump(text_file)
|
serialized = stringstore.to_bytes()
|
||||||
text_file.seek(0)
|
new_stringstore = StringStore().from_bytes(serialized)
|
||||||
new_stringstore = StringStore()
|
|
||||||
new_stringstore.load(text_file)
|
|
||||||
assert new_stringstore[store] == text
|
assert new_stringstore[store] == text
|
||||||
|
|
Loading…
Reference in New Issue