From 4a1ec332de71dedab605eae74aadf3f52ddb955d Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Mon, 3 Apr 2023 15:11:12 +0200 Subject: [PATCH] Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set (#12493) * Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set * Format --- spacy/tests/serialize/test_serialize_doc.py | 9 ++++++++- spacy/tests/serialize/test_serialize_docbin.py | 9 ++++++++- spacy/tokens/_serialize.py | 4 ++++ spacy/tokens/doc.pyx | 4 ++++ 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py index 15bf67bfd..eea13445e 100644 --- a/spacy/tests/serialize/test_serialize_doc.py +++ b/spacy/tests/serialize/test_serialize_doc.py @@ -213,6 +213,13 @@ def test_serialize_doc_exclude(en_vocab): def test_serialize_doc_span_groups(en_vocab): doc = Doc(en_vocab, words=["hello", "world", "!"]) - doc.spans["content"] = [doc[0:2]] + span = doc[0:2] + span.label_ = "test_serialize_doc_span_groups_label" + span.id_ = "test_serialize_doc_span_groups_id" + span.kb_id_ = "test_serialize_doc_span_groups_kb_id" + doc.spans["content"] = [span] new_doc = Doc(en_vocab).from_bytes(doc.to_bytes()) assert len(new_doc.spans["content"]) == 1 + assert new_doc.spans["content"][0].label_ == "test_serialize_doc_span_groups_label" + assert new_doc.spans["content"][0].id_ == "test_serialize_doc_span_groups_id" + assert new_doc.spans["content"][0].kb_id_ == "test_serialize_doc_span_groups_kb_id" diff --git a/spacy/tests/serialize/test_serialize_docbin.py b/spacy/tests/serialize/test_serialize_docbin.py index 9f8e5e06b..6f7b1001c 100644 --- a/spacy/tests/serialize/test_serialize_docbin.py +++ b/spacy/tests/serialize/test_serialize_docbin.py @@ -49,7 +49,11 @@ def test_serialize_doc_bin(): nlp = English() for doc in nlp.pipe(texts): doc.cats = cats - doc.spans["start"] = [doc[0:2]] + span = doc[0:2] + span.label_ = "UNUSUAL_SPAN_LABEL" + span.id_ = "UNUSUAL_SPAN_ID" + span.kb_id_ = "UNUSUAL_SPAN_KB_ID" + doc.spans["start"] = [span] doc[0].norm_ = "UNUSUAL_TOKEN_NORM" doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID" doc_bin.add(doc) @@ -63,6 +67,9 @@ def test_serialize_doc_bin(): assert doc.text == texts[i] assert doc.cats == cats assert len(doc.spans) == 1 + assert doc.spans["start"][0].label_ == "UNUSUAL_SPAN_LABEL" + assert doc.spans["start"][0].id_ == "UNUSUAL_SPAN_ID" + assert doc.spans["start"][0].kb_id_ == "UNUSUAL_SPAN_KB_ID" assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM" assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID" diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py index c4e8f26f4..73c857d1f 100644 --- a/spacy/tokens/_serialize.py +++ b/spacy/tokens/_serialize.py @@ -124,6 +124,10 @@ class DocBin: for key, group in doc.spans.items(): for span in group: self.strings.add(span.label_) + if span.kb_id in span.doc.vocab.strings: + self.strings.add(span.kb_id_) + if span.id in span.doc.vocab.strings: + self.strings.add(span.id_) def get_docs(self, vocab: Vocab) -> Iterator[Doc]: """Recover Doc objects from the annotations, using the given vocab. diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 3bc404dd0..a54b4ad3c 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1346,6 +1346,10 @@ cdef class Doc: for group in self.spans.values(): for span in group: strings.add(span.label_) + if span.kb_id in span.doc.vocab.strings: + strings.add(span.kb_id_) + if span.id in span.doc.vocab.strings: + strings.add(span.id_) # Msgpack doesn't distinguish between lists and tuples, which is # vexing for user data. As a best guess, we *know* that within # keys, we must have tuples. In values we just have to hope