Add ENT_ID and NORM to DocBin strings (#8054)

Save strings for token attributes `ENT_ID` and `NORM` in `DocBin`
strings.
This commit is contained in:
Adriane Boyd 2021-05-17 10:06:11 +02:00 committed by GitHub
parent 82fa81d095
commit fe3a4aa846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletions

View File

@ -64,13 +64,15 @@ def test_serialize_doc_span_groups(en_vocab):
def test_serialize_doc_bin(): def test_serialize_doc_bin():
doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True) doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE", "NORM", "ENT_ID"], store_user_data=True)
texts = ["Some text", "Lots of texts...", "..."] texts = ["Some text", "Lots of texts...", "..."]
cats = {"A": 0.5} cats = {"A": 0.5}
nlp = English() nlp = English()
for doc in nlp.pipe(texts): for doc in nlp.pipe(texts):
doc.cats = cats doc.cats = cats
doc.spans["start"] = [doc[0:2]] doc.spans["start"] = [doc[0:2]]
doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
doc_bin.add(doc) doc_bin.add(doc)
bytes_data = doc_bin.to_bytes() bytes_data = doc_bin.to_bytes()
@ -82,6 +84,8 @@ def test_serialize_doc_bin():
assert doc.text == texts[i] assert doc.text == texts[i]
assert doc.cats == cats assert doc.cats == cats
assert len(doc.spans) == 1 assert len(doc.spans) == 1
assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"
def test_serialize_doc_bin_unknown_spaces(en_vocab): def test_serialize_doc_bin_unknown_spaces(en_vocab):

View File

@ -103,10 +103,12 @@ class DocBin:
self.strings.add(token.text) self.strings.add(token.text)
self.strings.add(token.tag_) self.strings.add(token.tag_)
self.strings.add(token.lemma_) self.strings.add(token.lemma_)
self.strings.add(token.norm_)
self.strings.add(str(token.morph)) self.strings.add(str(token.morph))
self.strings.add(token.dep_) self.strings.add(token.dep_)
self.strings.add(token.ent_type_) self.strings.add(token.ent_type_)
self.strings.add(token.ent_kb_id_) self.strings.add(token.ent_kb_id_)
self.strings.add(token.ent_id_)
self.cats.append(doc.cats) self.cats.append(doc.cats)
self.user_data.append(srsly.msgpack_dumps(doc.user_data)) self.user_data.append(srsly.msgpack_dumps(doc.user_data))
self.span_groups.append(doc.spans.to_bytes()) self.span_groups.append(doc.spans.to_bytes())