Refactor `Doc.ents` setter to use `Doc.set_ents`

Additional changes:

* Ignore entity spans with missing labels
* Fix `ent_kb_id` setting in `Doc.set_ents`
This commit is contained in:
Adriane Boyd 2020-09-24 12:36:51 +02:00
parent b1a7d6c528
commit 8eaacaae97
3 changed files with 14 additions and 42 deletions

View File

@ -29,10 +29,10 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
ner.begin_training(lambda: [_ner_example(ner)]) ner.begin_training(lambda: [_ner_example(ner)])
ner(doc) ner(doc)
doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)] doc.ents = [("ANIMAL", 3, 4)]
assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"] assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]
doc.ents = [(doc.vocab.strings["WORD"], 0, 2)] doc.ents = [("WORD", 0, 2)]
assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"] assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"]

View File

@ -673,49 +673,16 @@ cdef class Doc:
# TODO: # TODO:
# 1. Test basic data-driven ORTH gazetteer # 1. Test basic data-driven ORTH gazetteer
# 2. Test more nuanced date and currency regex # 2. Test more nuanced date and currency regex
tokens_in_ents = {} cdef attr_t entity_type, kb_id
cdef attr_t entity_type cdef int ent_start, ent_end
cdef attr_t kb_id ent_spans = []
cdef int ent_start, ent_end, token_index
for ent_info in ents: for ent_info in ents:
entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info) entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info)
if isinstance(entity_type_, str): if isinstance(entity_type_, str):
self.vocab.strings.add(entity_type_) self.vocab.strings.add(entity_type_)
entity_type = self.vocab.strings.as_int(entity_type_) span = Span(self, ent_start, ent_end, label=entity_type_, kb_id=kb_id)
for token_index in range(ent_start, ent_end): ent_spans.append(span)
if token_index in tokens_in_ents: self.set_ents(ent_spans, default=SetEntsDefault.outside)
raise ValueError(Errors.E103.format(
span1=(tokens_in_ents[token_index][0],
tokens_in_ents[token_index][1],
self.vocab.strings[tokens_in_ents[token_index][2]]),
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id)
cdef int i
for i in range(self.length):
# default values
entity_type = 0
kb_id = 0
# Set ent_iob to Outside (2) by default
ent_iob = 2
# overwrite if the token was part of a specified entity
if i in tokens_in_ents.keys():
ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i]
if entity_type is None or entity_type <= 0:
# Only allow labelled spans
print(i, ent_start, ent_end, entity_type)
raise ValueError(Errors.E1013)
elif ent_start == i:
# Marking the start of an entity
ent_iob = 3
else:
# Marking the inside of an entity
ent_iob = 1
self.c[i].ent_type = entity_type
self.c[i].ent_kb_id = kb_id
self.c[i].ent_iob = ent_iob
def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside): def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside):
"""Set entity annotation. """Set entity annotation.
@ -734,6 +701,9 @@ cdef class Doc:
if default not in SetEntsDefault.values(): if default not in SetEntsDefault.values():
raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault))) raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault)))
# Ignore spans with missing labels
entities = [ent for ent in entities if ent.label > 0]
if blocked is None: if blocked is None:
blocked = tuple() blocked = tuple()
if missing is None: if missing is None:
@ -742,6 +712,7 @@ cdef class Doc:
outside = tuple() outside = tuple()
# Find all tokens covered by spans and check that none are overlapping # Find all tokens covered by spans and check that none are overlapping
cdef int i
seen_tokens = set() seen_tokens = set()
for span in itertools.chain.from_iterable([entities, blocked, missing, outside]): for span in itertools.chain.from_iterable([entities, blocked, missing, outside]):
if not isinstance(span, Span): if not isinstance(span, Span):
@ -761,6 +732,7 @@ cdef class Doc:
else: else:
self.c[i].ent_iob = 1 self.c[i].ent_iob = 1
self.c[i].ent_type = span.label self.c[i].ent_type = span.label
self.c[i].ent_kb_id = span.kb_id
for span in blocked: for span in blocked:
for i in range(span.start, span.end): for i in range(span.start, span.end):
self.c[i].ent_iob = 3 self.c[i].ent_iob = 3