diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index f81e4a96b..b82bab294 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -663,11 +663,14 @@ cdef class Doc: tokens_in_ents = {} cdef attr_t entity_type cdef attr_t kb_id - cdef int ent_start, ent_end + cdef int ent_start, ent_end, token_index for ent_info in ents: - entity_type, kb_id, ent_start, ent_end = get_entity_info(ent_info, self.vocab) + entity_type_, kb_id, ent_start, ent_end = get_entity_info(ent_info) + if isinstance(entity_type_, str): + self.vocab.strings.add(entity_type_) + entity_type = self.vocab.strings.as_int(entity_type_) for token_index in range(ent_start, ent_end): - if token_index in tokens_in_ents.keys(): + if token_index in tokens_in_ents: raise ValueError(Errors.E103.format( span1=(tokens_in_ents[token_index][0], tokens_in_ents[token_index][1], @@ -1583,7 +1586,7 @@ def fix_attributes(doc, attributes): attributes[ENT_TYPE] = attributes["ent_type"] -def get_entity_info(ent_info, vocab): +def get_entity_info(ent_info): if isinstance(ent_info, Span): ent_type = ent_info.label ent_kb_id = ent_info.kb_id @@ -1596,6 +1599,4 @@ def get_entity_info(ent_info, vocab): ent_type, ent_kb_id, start, end = ent_info else: ent_id, ent_kb_id, ent_type, start, end = ent_info - if isinstance(ent_type, str): - ent_type = vocab.strings.add(ent_type) return ent_type, ent_kb_id, start, end