mirror of https://github.com/explosion/spaCy.git
Refactor Doc.ents setter to use Doc.set_ents
Additional changes:

* Entity spans with missing labels are ignored
* Fix `ent_kb_id` setting in `Doc.set_ents`
This commit is contained in:
parent
b1a7d6c528
commit
8eaacaae97
|
@ -29,10 +29,10 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
||||||
ner.begin_training(lambda: [_ner_example(ner)])
|
ner.begin_training(lambda: [_ner_example(ner)])
|
||||||
ner(doc)
|
ner(doc)
|
||||||
|
|
||||||
doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)]
|
doc.ents = [("ANIMAL", 3, 4)]
|
||||||
assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]
|
assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"]
|
||||||
|
|
||||||
doc.ents = [(doc.vocab.strings["WORD"], 0, 2)]
|
doc.ents = [("WORD", 0, 2)]
|
||||||
assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"]
|
assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -534,4 +534,4 @@ def test_doc_ents_setter():
|
||||||
vocab = Vocab()
|
vocab = Vocab()
|
||||||
ents = [("HELLO", 0, 2), (vocab.strings.add("WORLD"), 3, 5)]
|
ents = [("HELLO", 0, 2), (vocab.strings.add("WORLD"), 3, 5)]
|
||||||
doc = Doc(vocab, words=words, ents=ents)
|
doc = Doc(vocab, words=words, ents=ents)
|
||||||
assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"]
|
assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"]
|
||||||
|
|
|
@ -673,49 +673,16 @@ cdef class Doc:
|
||||||
# TODO:
|
# TODO:
|
||||||
# 1. Test basic data-driven ORTH gazetteer
|
# 1. Test basic data-driven ORTH gazetteer
|
||||||
# 2. Test more nuanced date and currency regex
|
# 2. Test more nuanced date and currency regex
|
||||||
tokens_in_ents = {}
|
cdef attr_t entity_type, kb_id
|
||||||
cdef attr_t entity_type
|
cdef int ent_start, ent_end
|
||||||
cdef attr_t kb_id
|
ent_spans = []
|
||||||
cdef int ent_start, ent_end, token_index
|
|
||||||
for ent_info in ents:
|
for ent_info in ents:
|
||||||
entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info)
|
entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info)
|
||||||
if isinstance(entity_type_, str):
|
if isinstance(entity_type_, str):
|
||||||
self.vocab.strings.add(entity_type_)
|
self.vocab.strings.add(entity_type_)
|
||||||
entity_type = self.vocab.strings.as_int(entity_type_)
|
span = Span(self, ent_start, ent_end, label=entity_type_, kb_id=kb_id)
|
||||||
for token_index in range(ent_start, ent_end):
|
ent_spans.append(span)
|
||||||
if token_index in tokens_in_ents:
|
self.set_ents(ent_spans, default=SetEntsDefault.outside)
|
||||||
raise ValueError(Errors.E103.format(
|
|
||||||
span1=(tokens_in_ents[token_index][0],
|
|
||||||
tokens_in_ents[token_index][1],
|
|
||||||
self.vocab.strings[tokens_in_ents[token_index][2]]),
|
|
||||||
span2=(ent_start, ent_end, self.vocab.strings[entity_type])))
|
|
||||||
tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id)
|
|
||||||
cdef int i
|
|
||||||
for i in range(self.length):
|
|
||||||
# default values
|
|
||||||
entity_type = 0
|
|
||||||
kb_id = 0
|
|
||||||
|
|
||||||
# Set ent_iob to Outside (2) by default
|
|
||||||
ent_iob = 2
|
|
||||||
|
|
||||||
# overwrite if the token was part of a specified entity
|
|
||||||
if i in tokens_in_ents.keys():
|
|
||||||
ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i]
|
|
||||||
if entity_type is None or entity_type <= 0:
|
|
||||||
# Only allow labelled spans
|
|
||||||
print(i, ent_start, ent_end, entity_type)
|
|
||||||
raise ValueError(Errors.E1013)
|
|
||||||
elif ent_start == i:
|
|
||||||
# Marking the start of an entity
|
|
||||||
ent_iob = 3
|
|
||||||
else:
|
|
||||||
# Marking the inside of an entity
|
|
||||||
ent_iob = 1
|
|
||||||
|
|
||||||
self.c[i].ent_type = entity_type
|
|
||||||
self.c[i].ent_kb_id = kb_id
|
|
||||||
self.c[i].ent_iob = ent_iob
|
|
||||||
|
|
||||||
def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside):
|
def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside):
|
||||||
"""Set entity annotation.
|
"""Set entity annotation.
|
||||||
|
@ -734,6 +701,9 @@ cdef class Doc:
|
||||||
if default not in SetEntsDefault.values():
|
if default not in SetEntsDefault.values():
|
||||||
raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault)))
|
raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault)))
|
||||||
|
|
||||||
|
# Ignore spans with missing labels
|
||||||
|
entities = [ent for ent in entities if ent.label > 0]
|
||||||
|
|
||||||
if blocked is None:
|
if blocked is None:
|
||||||
blocked = tuple()
|
blocked = tuple()
|
||||||
if missing is None:
|
if missing is None:
|
||||||
|
@ -742,6 +712,7 @@ cdef class Doc:
|
||||||
outside = tuple()
|
outside = tuple()
|
||||||
|
|
||||||
# Find all tokens covered by spans and check that none are overlapping
|
# Find all tokens covered by spans and check that none are overlapping
|
||||||
|
cdef int i
|
||||||
seen_tokens = set()
|
seen_tokens = set()
|
||||||
for span in itertools.chain.from_iterable([entities, blocked, missing, outside]):
|
for span in itertools.chain.from_iterable([entities, blocked, missing, outside]):
|
||||||
if not isinstance(span, Span):
|
if not isinstance(span, Span):
|
||||||
|
@ -761,6 +732,7 @@ cdef class Doc:
|
||||||
else:
|
else:
|
||||||
self.c[i].ent_iob = 1
|
self.c[i].ent_iob = 1
|
||||||
self.c[i].ent_type = span.label
|
self.c[i].ent_type = span.label
|
||||||
|
self.c[i].ent_kb_id = span.kb_id
|
||||||
for span in blocked:
|
for span in blocked:
|
||||||
for i in range(span.start, span.end):
|
for i in range(span.start, span.end):
|
||||||
self.c[i].ent_iob = 3
|
self.c[i].ent_iob = 3
|
||||||
|
|
Loading…
Reference in New Issue