diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py index 23afaf26c..03b8e37f5 100644 --- a/spacy/tests/serialize/test_serialize_doc.py +++ b/spacy/tests/serialize/test_serialize_doc.py @@ -127,3 +127,22 @@ def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_v doc_2 = list(doc_bin_2.get_docs(en_vocab))[0] assert doc_2._.foo == reader_value Underscore.doc_extensions = {} + + +@pytest.mark.parametrize( + "value", + [ + ("bar"), + (None), + ({"a": 1, "b": 34.5}), + ([1, 2, "foo"]), + ], +) +def test_serialize_roundtrip_context(en_vocab, value): + """Test that Doc._context is serialized with to/from_bytes.""" + doc = Doc(en_vocab, words=["hello", "world"]) + doc._context = value + data = doc.to_bytes() + doc2 = Doc(en_vocab) + doc2.from_bytes(data) + assert doc._context == doc2._context diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 3709cece0..0b5fccf50 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -248,6 +248,7 @@ cdef class Doc: self.tensor = numpy.zeros((0,), dtype="float32") self.user_data = {} if user_data is None else user_data self._vector = None + self._context = None self.noun_chunks_iterator = self.vocab.get_noun_chunks cdef bint has_space if words is None and spaces is not None: @@ -1331,7 +1332,8 @@ cdef class Doc: "cats": lambda: self.cats, "spans": lambda: self.spans.to_bytes(), "strings": lambda: list(strings), - "has_unknown_spaces": lambda: self.has_unknown_spaces + "has_unknown_spaces": lambda: self.has_unknown_spaces, + "_context": lambda: srsly.msgpack_dumps(self._context), } if "user_data" not in exclude and self.user_data: user_data_keys, user_data_values = list(zip(*self.user_data.items())) @@ -1375,6 +1377,8 @@ cdef class Doc: self.vocab.strings.add(s) if "has_unknown_spaces" not in exclude and "has_unknown_spaces" in msg: self.has_unknown_spaces = msg["has_unknown_spaces"] + if "_context" not in exclude and "_context" in msg: + self._context = srsly.msgpack_loads(msg["_context"]) start = 0 cdef const LexemeC* lex cdef str orth_