Serialize _context with Doc

This commit is contained in:
Adriane Boyd 2021-11-02 16:09:46 +01:00
parent 5a979137a7
commit 161f1fac91
2 changed files with 24 additions and 1 deletions

View File

@ -127,3 +127,22 @@ def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_v
doc_2 = list(doc_bin_2.get_docs(en_vocab))[0] doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
assert doc_2._.foo == reader_value assert doc_2._.foo == reader_value
Underscore.doc_extensions = {} Underscore.doc_extensions = {}
@pytest.mark.parametrize(
"value",
[
("bar"),
(None),
({"a": 1, "b": 34.5}),
([1, 2, "foo"]),
],
)
def test_serialize_roundtrip_context(en_vocab, value):
"""Test that Doc._context is serialized with to/from_bytes."""
doc = Doc(en_vocab, words=["hello", "world"])
doc._context = value
data = doc.to_bytes()
doc2 = Doc(en_vocab)
doc2.from_bytes(data)
assert doc._context == doc2._context

View File

@ -248,6 +248,7 @@ cdef class Doc:
self.tensor = numpy.zeros((0,), dtype="float32") self.tensor = numpy.zeros((0,), dtype="float32")
self.user_data = {} if user_data is None else user_data self.user_data = {} if user_data is None else user_data
self._vector = None self._vector = None
self._context = None
self.noun_chunks_iterator = self.vocab.get_noun_chunks self.noun_chunks_iterator = self.vocab.get_noun_chunks
cdef bint has_space cdef bint has_space
if words is None and spaces is not None: if words is None and spaces is not None:
@ -1331,7 +1332,8 @@ cdef class Doc:
"cats": lambda: self.cats, "cats": lambda: self.cats,
"spans": lambda: self.spans.to_bytes(), "spans": lambda: self.spans.to_bytes(),
"strings": lambda: list(strings), "strings": lambda: list(strings),
"has_unknown_spaces": lambda: self.has_unknown_spaces "has_unknown_spaces": lambda: self.has_unknown_spaces,
"_context": lambda: srsly.msgpack_dumps(self._context),
} }
if "user_data" not in exclude and self.user_data: if "user_data" not in exclude and self.user_data:
user_data_keys, user_data_values = list(zip(*self.user_data.items())) user_data_keys, user_data_values = list(zip(*self.user_data.items()))
@ -1375,6 +1377,8 @@ cdef class Doc:
self.vocab.strings.add(s) self.vocab.strings.add(s)
if "has_unknown_spaces" not in exclude and "has_unknown_spaces" in msg: if "has_unknown_spaces" not in exclude and "has_unknown_spaces" in msg:
self.has_unknown_spaces = msg["has_unknown_spaces"] self.has_unknown_spaces = msg["has_unknown_spaces"]
if "_context" not in exclude and "_context" in msg:
self._context = srsly.msgpack_loads(msg["_context"])
start = 0 start = 0
cdef const LexemeC* lex cdef const LexemeC* lex
cdef str orth_ cdef str orth_