mirror of https://github.com/explosion/spaCy.git
Serialize _context with Doc
This commit is contained in:
parent
5a979137a7
commit
161f1fac91
|
@ -127,3 +127,22 @@ def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_v
|
|||
doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
|
||||
assert doc_2._.foo == reader_value
|
||||
Underscore.doc_extensions = {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value",
|
||||
[
|
||||
("bar"),
|
||||
(None),
|
||||
({"a": 1, "b": 34.5}),
|
||||
([1, 2, "foo"]),
|
||||
],
|
||||
)
|
||||
def test_serialize_roundtrip_context(en_vocab, value):
|
||||
"""Test that Doc._context is serialized with to/from_bytes."""
|
||||
doc = Doc(en_vocab, words=["hello", "world"])
|
||||
doc._context = value
|
||||
data = doc.to_bytes()
|
||||
doc2 = Doc(en_vocab)
|
||||
doc2.from_bytes(data)
|
||||
assert doc._context == doc2._context
|
||||
|
|
|
@ -248,6 +248,7 @@ cdef class Doc:
|
|||
self.tensor = numpy.zeros((0,), dtype="float32")
|
||||
self.user_data = {} if user_data is None else user_data
|
||||
self._vector = None
|
||||
self._context = None
|
||||
self.noun_chunks_iterator = self.vocab.get_noun_chunks
|
||||
cdef bint has_space
|
||||
if words is None and spaces is not None:
|
||||
|
@ -1331,7 +1332,8 @@ cdef class Doc:
|
|||
"cats": lambda: self.cats,
|
||||
"spans": lambda: self.spans.to_bytes(),
|
||||
"strings": lambda: list(strings),
|
||||
"has_unknown_spaces": lambda: self.has_unknown_spaces
|
||||
"has_unknown_spaces": lambda: self.has_unknown_spaces,
|
||||
"_context": lambda: srsly.msgpack_dumps(self._context),
|
||||
}
|
||||
if "user_data" not in exclude and self.user_data:
|
||||
user_data_keys, user_data_values = list(zip(*self.user_data.items()))
|
||||
|
@ -1375,6 +1377,8 @@ cdef class Doc:
|
|||
self.vocab.strings.add(s)
|
||||
if "has_unknown_spaces" not in exclude and "has_unknown_spaces" in msg:
|
||||
self.has_unknown_spaces = msg["has_unknown_spaces"]
|
||||
if "_context" not in exclude and "_context" in msg:
|
||||
self._context = srsly.msgpack_loads(msg["_context"])
|
||||
start = 0
|
||||
cdef const LexemeC* lex
|
||||
cdef str orth_
|
||||
|
|
Loading…
Reference in New Issue