From 8eaacaae97f0caf77576e843a8d6bcf866c79236 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 24 Sep 2020 12:36:51 +0200 Subject: [PATCH] Refactor Doc.ents setter to use Doc.set_ents Additional changes: * Entity spans with missing labels are ignored * Fix ent_kb_id setting in `Doc.set_ents` --- spacy/tests/doc/test_add_entities.py | 4 +-- spacy/tests/doc/test_doc_api.py | 2 +- spacy/tokens/doc.pyx | 50 ++++++---------------------- 3 files changed, 14 insertions(+), 42 deletions(-) diff --git a/spacy/tests/doc/test_add_entities.py b/spacy/tests/doc/test_add_entities.py index 40aff8e31..615ab9e5b 100644 --- a/spacy/tests/doc/test_add_entities.py +++ b/spacy/tests/doc/test_add_entities.py @@ -29,10 +29,10 @@ def test_doc_add_entities_set_ents_iob(en_vocab): ner.begin_training(lambda: [_ner_example(ner)]) ner(doc) - doc.ents = [(doc.vocab.strings["ANIMAL"], 3, 4)] + doc.ents = [("ANIMAL", 3, 4)] assert [w.ent_iob_ for w in doc] == ["O", "O", "O", "B"] - doc.ents = [(doc.vocab.strings["WORD"], 0, 2)] + doc.ents = [("WORD", 0, 2)] assert [w.ent_iob_ for w in doc] == ["B", "I", "O", "O"] diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 892b65cf4..e5e72fe2a 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -534,4 +534,4 @@ def test_doc_ents_setter(): vocab = Vocab() ents = [("HELLO", 0, 2), (vocab.strings.add("WORLD"), 3, 5)] doc = Doc(vocab, words=words, ents=ents) - assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"] \ No newline at end of file + assert [e.label_ for e in doc.ents] == ["HELLO", "WORLD"] diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 4bf6f0e5e..670c7440f 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -673,49 +673,16 @@ cdef class Doc: # TODO: # 1. Test basic data-driven ORTH gazetteer # 2. Test more nuanced date and currency regex - tokens_in_ents = {} - cdef attr_t entity_type - cdef attr_t kb_id - cdef int ent_start, ent_end, token_index + cdef attr_t entity_type, kb_id + cdef int ent_start, ent_end + ent_spans = [] for ent_info in ents: entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info) if isinstance(entity_type_, str): self.vocab.strings.add(entity_type_) - entity_type = self.vocab.strings.as_int(entity_type_) - for token_index in range(ent_start, ent_end): - if token_index in tokens_in_ents: - raise ValueError(Errors.E103.format( - span1=(tokens_in_ents[token_index][0], - tokens_in_ents[token_index][1], - self.vocab.strings[tokens_in_ents[token_index][2]]), - span2=(ent_start, ent_end, self.vocab.strings[entity_type]))) - tokens_in_ents[token_index] = (ent_start, ent_end, entity_type, kb_id) - cdef int i - for i in range(self.length): - # default values - entity_type = 0 - kb_id = 0 - - # Set ent_iob to Outside (2) by default - ent_iob = 2 - - # overwrite if the token was part of a specified entity - if i in tokens_in_ents.keys(): - ent_start, ent_end, entity_type, kb_id = tokens_in_ents[i] - if entity_type is None or entity_type <= 0: - # Only allow labelled spans - print(i, ent_start, ent_end, entity_type) - raise ValueError(Errors.E1013) - elif ent_start == i: - # Marking the start of an entity - ent_iob = 3 - else: - # Marking the inside of an entity - ent_iob = 1 - - self.c[i].ent_type = entity_type - self.c[i].ent_kb_id = kb_id - self.c[i].ent_iob = ent_iob + span = Span(self, ent_start, ent_end, label=entity_type_, kb_id=kb_id) + ent_spans.append(span) + self.set_ents(ent_spans, default=SetEntsDefault.outside) def set_ents(self, entities, *, blocked=None, missing=None, outside=None, default=SetEntsDefault.outside): """Set entity annotation. @@ -734,6 +701,9 @@ cdef class Doc: if default not in SetEntsDefault.values(): raise ValueError(Errors.E1011.format(default=default, modes=", ".join(SetEntsDefault))) + # Ignore spans with missing labels + entities = [ent for ent in entities if ent.label > 0] + if blocked is None: blocked = tuple() if missing is None: @@ -742,6 +712,7 @@ cdef class Doc: outside = tuple() # Find all tokens covered by spans and check that none are overlapping + cdef int i seen_tokens = set() for span in itertools.chain.from_iterable([entities, blocked, missing, outside]): if not isinstance(span, Span): @@ -761,6 +732,7 @@ cdef class Doc: else: self.c[i].ent_iob = 1 self.c[i].ent_type = span.label + self.c[i].ent_kb_id = span.kb_id for span in blocked: for i in range(span.start, span.end): self.c[i].ent_iob = 3