mirror of https://github.com/explosion/spaCy.git
Refactor Doc.ents setter to use Doc.set_ents
Additional changes: * Entity spans with missing labels are ignored * Fix ent_kb_id setting in `Doc.set_ents`
This commit is contained in:
parent
b1a7d6c528
commit
8eaacaae97
|
@ -29,10 +29,10 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
|||
ner.begin_training(lambda: [_ner_example(ner)])
|
||||
ner(doc)
|
||||
|
||||
doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)]
|
||||
doc.ents = [("ANIMAL", 3, 4)]
|
||||
assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]
|
||||
|
||||
doc.ents = [(doc.vocab.strings["WORD"], 0, 2)]
|
||||
doc.ents = [("WORD", 0, 2)]
|
||||
assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"]
|
||||
|
||||
|
||||
|
|
|
@ -534,4 +534,4 @@ def test_doc_ents_setter():
|
|||
vocab = Vocab()
|
||||
ents = [("HELLO", 0, 2), (vocab.strings.add("WORLD"), 3, 5)]
|
||||
doc = Doc(vocab, words=words, ents=ents)
|
||||
assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"]
|
||||
assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"]
|
||||
|
|
|
@ -673,49 +673,16 @@ cdef class Doc:
|
|||
# TODO:
|
||||
# 1. Test basic data-driven ORTH gazetteer
|
||||
# 2. Test more nuanced date and currency regex
|
||||
tokens_in_ents = {}
|
||||
cdef attr_t entity_type
|
||||
cdef attr_t kb_id
|
||||
cdef int ent_start, ent_end, token_index
|
||||
cdef attr_t entity_type, kb_id
|
||||
cdef int ent_start, ent_end
|
||||
ent_spans = []
|
||||
for ent_info in ents:
|
||||
entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info)
|
||||
if isinstance(entity_type_, str):
|
||||
self.vocab.strings.add(entity_type_)
|
||||
entity_type = self.vocab.strings.as_int(entity_type_)
|
||||
for token_index in range(ent_start, ent_end):
|
||||
if token_index in tokens_in_ents:
|
||||
raise ValueError(Errors.E103.format(
|
||||
span1=(tokens_in_ents[token_index][0],
|
||||
tokens_in_ents[token_index][1],
|
||||
self.vocab.strings[tokens_in_ents[token_index][2]]),
|
||||
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
|
||||
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id)
|
||||
cdef int i
|
||||
for i in range(self.length):
|
||||
# default values
|
||||
entity_type = 0
|
||||
kb_id = 0
|
||||
|
||||
# Set ent_iob to Outside (2) by default
|
||||
ent_iob = 2
|
||||
|
||||
# overwrite if the token was part of a specified entity
|
||||
if i in tokens_in_ents.keys():
|
||||
ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i]
|
||||
if entity_type is None or entity_type <= 0:
|
||||
# Only allow labelled spans
|
||||
print(i, ent_start, ent_end, entity_type)
|
||||
raise ValueError(Errors.E1013)
|
||||
elif ent_start == i:
|
||||
# Marking the start of an entity
|
||||
ent_iob = 3
|
||||
else:
|
||||
# Marking the inside of an entity
|
||||
ent_iob = 1
|
||||
|
||||
self.c[i].ent_type = entity_type
|
||||
self.c[i].ent_kb_id = kb_id
|
||||
self.c[i].ent_iob = ent_iob
|
||||
span = Span(self, ent_start, ent_end, label=entity_type_, kb_id=kb_id)
|
||||
ent_spans.append(span)
|
||||
self.set_ents(ent_spans, default=SetEntsDefault.outside)
|
||||
|
||||
def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside):
|
||||
"""Set entity annotation.
|
||||
|
@ -734,6 +701,9 @@ cdef class Doc:
|
|||
if default not in SetEntsDefault.values():
|
||||
raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault)))
|
||||
|
||||
# Ignore spans with missing labels
|
||||
entities = [ent for ent in entities if ent.label > 0]
|
||||
|
||||
if blocked is None:
|
||||
blocked = tuple()
|
||||
if missing is None:
|
||||
|
@ -742,6 +712,7 @@ cdef class Doc:
|
|||
outside = tuple()
|
||||
|
||||
# Find all tokens covered by spans and check that none are overlapping
|
||||
cdef int i
|
||||
seen_tokens = set()
|
||||
for span in itertools.chain.from_iterable([entities, blocked, missing, outside]):
|
||||
if not isinstance(span, Span):
|
||||
|
@ -761,6 +732,7 @@ cdef class Doc:
|
|||
else:
|
||||
self.c[i].ent_iob = 1
|
||||
self.c[i].ent_type = span.label
|
||||
self.c[i].ent_kb_id = span.kb_id
|
||||
for span in blocked:
|
||||
for i in range(span.start, span.end):
|
||||
self.c[i].ent_iob = 3
|
||||
|
|
Loading…
Reference in New Issue