diff --git a/spacy/strings.pyx b/spacy/strings.pyx index e993f1423..b704ac789 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -7,9 +7,12 @@ from libc.string cimport memcpy from libc.stdint cimport uint64_t, uint32_t from murmurhash.mrmr cimport hash64, hash32 from preshed.maps cimport map_iter, key_t +from libc.stdint cimport uint32_t +import ujson +import dill from .typedefs cimport hash_t -from libc.stdint cimport uint32_t +from . import util cpdef hash_t hash_string(unicode string) except 0: @@ -92,14 +95,6 @@ cdef class StringStore: def __get__(self): return self.size -1 - def __reduce__(self): - # TODO: OOV words, for the is_frozen stuff? - if self.is_frozen: - raise NotImplementedError( - "Currently missing support for pickling StringStore when " - "is_frozen=True") - return (StringStore, (list(self),)) - def __len__(self): """The number of strings in the store. @@ -186,7 +181,10 @@ cdef class StringStore: path (unicode or Path): A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. """ - raise NotImplementedError() + path = util.ensure_path(path) + strings = list(self) + with path.open('w') as file_: + ujson.dump(strings, file_) def from_disk(self, path): """Loads state from a directory. Modifies the object in place and @@ -196,7 +194,11 @@ cdef class StringStore: strings or `Path`-like objects. RETURNS (StringStore): The modified `StringStore` object. """ - raise NotImplementedError() + path = util.ensure_path(path) + with path.open('r') as file_: + strings = ujson.load(file_) + self._reset_and_load(strings) + return self def to_bytes(self, **exclude): """Serialize the current state to a binary string. @@ -204,7 +206,7 @@ cdef class StringStore: **exclude: Named attributes to prevent from being serialized. RETURNS (bytes): The serialized form of the `StringStore` object. """ - raise NotImplementedError() + return ujson.dumps(list(self)) def from_bytes(self, bytes_data, **exclude): """Load state from a binary string. @@ -213,7 +215,9 @@ cdef class StringStore: **exclude: Named attributes to prevent from being loaded. RETURNS (StringStore): The `StringStore` object. """ - raise NotImplementedError() + strings = ujson.loads(bytes_data) + self._reset_and_load(strings) + return self def set_frozen(self, bint is_frozen): # TODO @@ -222,6 +226,17 @@ cdef class StringStore: def flush_oov(self): self._oov = PreshMap() + def _reset_and_load(self, strings, freeze=False): + self.mem = Pool() + self._map = PreshMap() + self._oov = PreshMap() + self._resize_at = 10000 + self.c = self.mem.alloc(self._resize_at, sizeof(Utf8Str)) + self.size = 1 + for string in strings: + _ = self[string] + self.is_frozen = freeze + cdef const Utf8Str* intern_unicode(self, unicode py_string): # 0 means missing, but we don't bother offsetting the index. cdef bytes byte_string = py_string.encode('utf8') diff --git a/spacy/tests/stringstore/test_stringstore.py b/spacy/tests/stringstore/test_stringstore.py index ebbec01d9..e3c94e33b 100644 --- a/spacy/tests/stringstore/test_stringstore.py +++ b/spacy/tests/stringstore/test_stringstore.py @@ -69,10 +69,8 @@ def test_stringstore_massive_strings(stringstore): @pytest.mark.parametrize('text', ["qqqqq"]) -def test_stringstore_dump_load(stringstore, text_file, text): +def test_stringstore_to_bytes(stringstore, text): store = stringstore[text] - stringstore.dump(text_file) - text_file.seek(0) - new_stringstore = StringStore() - new_stringstore.load(text_file) + serialized = stringstore.to_bytes() + new_stringstore = StringStore().from_bytes(serialized) assert new_stringstore[store] == text