mirror of https://github.com/explosion/spaCy.git
Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set (#12493)
* Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set * Format
This commit is contained in:
parent
4538ceb507
commit
4a1ec332de
|
@ -213,6 +213,13 @@ def test_serialize_doc_exclude(en_vocab):
|
|||
|
||||
def test_serialize_doc_span_groups(en_vocab):
|
||||
doc = Doc(en_vocab, words=["hello", "world", "!"])
|
||||
doc.spans["content"] = [doc[0:2]]
|
||||
span = doc[0:2]
|
||||
span.label_ = "test_serialize_doc_span_groups_label"
|
||||
span.id_ = "test_serialize_doc_span_groups_id"
|
||||
span.kb_id_ = "test_serialize_doc_span_groups_kb_id"
|
||||
doc.spans["content"] = [span]
|
||||
new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
|
||||
assert len(new_doc.spans["content"]) == 1
|
||||
assert new_doc.spans["content"][0].label_ == "test_serialize_doc_span_groups_label"
|
||||
assert new_doc.spans["content"][0].id_ == "test_serialize_doc_span_groups_id"
|
||||
assert new_doc.spans["content"][0].kb_id_ == "test_serialize_doc_span_groups_kb_id"
|
||||
|
|
|
@ -49,7 +49,11 @@ def test_serialize_doc_bin():
|
|||
nlp = English()
|
||||
for doc in nlp.pipe(texts):
|
||||
doc.cats = cats
|
||||
doc.spans["start"] = [doc[0:2]]
|
||||
span = doc[0:2]
|
||||
span.label_ = "UNUSUAL_SPAN_LABEL"
|
||||
span.id_ = "UNUSUAL_SPAN_ID"
|
||||
span.kb_id_ = "UNUSUAL_SPAN_KB_ID"
|
||||
doc.spans["start"] = [span]
|
||||
doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
|
||||
doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
|
||||
doc_bin.add(doc)
|
||||
|
@ -63,6 +67,9 @@ def test_serialize_doc_bin():
|
|||
assert doc.text == texts[i]
|
||||
assert doc.cats == cats
|
||||
assert len(doc.spans) == 1
|
||||
assert doc.spans["start"][0].label_ == "UNUSUAL_SPAN_LABEL"
|
||||
assert doc.spans["start"][0].id_ == "UNUSUAL_SPAN_ID"
|
||||
assert doc.spans["start"][0].kb_id_ == "UNUSUAL_SPAN_KB_ID"
|
||||
assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
|
||||
assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"
|
||||
|
||||
|
|
|
@ -124,6 +124,10 @@ class DocBin:
|
|||
for key, group in doc.spans.items():
|
||||
for span in group:
|
||||
self.strings.add(span.label_)
|
||||
if span.kb_id in span.doc.vocab.strings:
|
||||
self.strings.add(span.kb_id_)
|
||||
if span.id in span.doc.vocab.strings:
|
||||
self.strings.add(span.id_)
|
||||
|
||||
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
|
||||
"""Recover Doc objects from the annotations, using the given vocab.
|
||||
|
|
|
@ -1346,6 +1346,10 @@ cdef class Doc:
|
|||
for group in self.spans.values():
|
||||
for span in group:
|
||||
strings.add(span.label_)
|
||||
if span.kb_id in span.doc.vocab.strings:
|
||||
strings.add(span.kb_id_)
|
||||
if span.id in span.doc.vocab.strings:
|
||||
strings.add(span.id_)
|
||||
# Msgpack doesn't distinguish between lists and tuples, which is
|
||||
# vexing for user data. As a best guess, we *know* that within
|
||||
# keys, we must have tuples. In values we just have to hope
|
||||
|
|
Loading…
Reference in New Issue