mirror of https://github.com/explosion/spaCy.git
Serialize _context with Doc
This commit is contained in:
parent
5a979137a7
commit
161f1fac91
|
@ -127,3 +127,22 @@ def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_v
|
||||||
doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
|
doc_2 = list(doc_bin_2.get_docs(en_vocab))[0]
|
||||||
assert doc_2._.foo == reader_value
|
assert doc_2._.foo == reader_value
|
||||||
Underscore.doc_extensions = {}
|
Underscore.doc_extensions = {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value",
|
||||||
|
[
|
||||||
|
("bar"),
|
||||||
|
(None),
|
||||||
|
({"a": 1, "b": 34.5}),
|
||||||
|
([1, 2, "foo"]),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_serialize_roundtrip_context(en_vocab, value):
|
||||||
|
"""Test that Doc._context is serialized with to/from_bytes."""
|
||||||
|
doc = Doc(en_vocab, words=["hello", "world"])
|
||||||
|
doc._context = value
|
||||||
|
data = doc.to_bytes()
|
||||||
|
doc2 = Doc(en_vocab)
|
||||||
|
doc2.from_bytes(data)
|
||||||
|
assert doc._context == doc2._context
|
||||||
|
|
|
@ -248,6 +248,7 @@ cdef class Doc:
|
||||||
self.tensor = numpy.zeros((0,), dtype="float32")
|
self.tensor = numpy.zeros((0,), dtype="float32")
|
||||||
self.user_data = {} if user_data is None else user_data
|
self.user_data = {} if user_data is None else user_data
|
||||||
self._vector = None
|
self._vector = None
|
||||||
|
self._context = None
|
||||||
self.noun_chunks_iterator = self.vocab.get_noun_chunks
|
self.noun_chunks_iterator = self.vocab.get_noun_chunks
|
||||||
cdef bint has_space
|
cdef bint has_space
|
||||||
if words is None and spaces is not None:
|
if words is None and spaces is not None:
|
||||||
|
@ -1331,7 +1332,8 @@ cdef class Doc:
|
||||||
"cats": lambda: self.cats,
|
"cats": lambda: self.cats,
|
||||||
"spans": lambda: self.spans.to_bytes(),
|
"spans": lambda: self.spans.to_bytes(),
|
||||||
"strings": lambda: list(strings),
|
"strings": lambda: list(strings),
|
||||||
"has_unknown_spaces": lambda: self.has_unknown_spaces
|
"has_unknown_spaces": lambda: self.has_unknown_spaces,
|
||||||
|
"_context": lambda: srsly.msgpack_dumps(self._context),
|
||||||
}
|
}
|
||||||
if "user_data" not in exclude and self.user_data:
|
if "user_data" not in exclude and self.user_data:
|
||||||
user_data_keys, user_data_values = list(zip(*self.user_data.items()))
|
user_data_keys, user_data_values = list(zip(*self.user_data.items()))
|
||||||
|
@ -1375,6 +1377,8 @@ cdef class Doc:
|
||||||
self.vocab.strings.add(s)
|
self.vocab.strings.add(s)
|
||||||
if "has_unknown_spaces" not in exclude and "has_unknown_spaces" in msg:
|
if "has_unknown_spaces" not in exclude and "has_unknown_spaces" in msg:
|
||||||
self.has_unknown_spaces = msg["has_unknown_spaces"]
|
self.has_unknown_spaces = msg["has_unknown_spaces"]
|
||||||
|
if "_context" not in exclude and "_context" in msg:
|
||||||
|
self._context = srsly.msgpack_loads(msg["_context"])
|
||||||
start = 0
|
start = 0
|
||||||
cdef const LexemeC* lex
|
cdef const LexemeC* lex
|
||||||
cdef str orth_
|
cdef str orth_
|
||||||
|
|
Loading…
Reference in New Issue