diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 98414736b..4e04b96b5 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1302,71 +1302,72 @@ class EntityLinker(Pipe): # Looping through each sentence and each entity # This may go wrong if there are entities across sentences - which shouldn't happen normally. for sent_index, sent in enumerate(sentences): - # get n_neightbour sentences, clipped to the length of the document - start_sentence = max(0, sent_index - self.n_sents) - end_sentence = min(len(sentences) -1, sent_index + self.n_sents) + if sent.ents: + # get n_neightbour sentences, clipped to the length of the document + start_sentence = max(0, sent_index - self.n_sents) + end_sentence = min(len(sentences) -1, sent_index + self.n_sents) - start_token = sentences[start_sentence].start - end_token = sentences[end_sentence].end + start_token = sentences[start_sentence].start + end_token = sentences[end_sentence].end - sent_doc = doc[start_token:end_token].as_doc() - # currently, the context is the same for each entity in a sentence (should be refined) - sentence_encoding = self.model.predict([sent_doc])[0] - xp = get_array_module(sentence_encoding) - sentence_encoding_t = sentence_encoding.T - sentence_norm = xp.linalg.norm(sentence_encoding_t) + sent_doc = doc[start_token:end_token].as_doc() + # currently, the context is the same for each entity in a sentence (should be refined) + sentence_encoding = self.model.predict([sent_doc])[0] + xp = get_array_module(sentence_encoding) + sentence_encoding_t = sentence_encoding.T + sentence_norm = xp.linalg.norm(sentence_encoding_t) - for ent in sent.ents: - entity_count += 1 + for ent in sent.ents: + entity_count += 1 - to_discard = self.cfg.get("labels_discard", []) - if to_discard and ent.label_ in to_discard: - # ignoring this entity - setting to NIL - final_kb_ids.append(self.NIL) - final_tensors.append(sentence_encoding) - - else: - candidates = self.kb.get_candidates(ent.text) - if not candidates: - # no prediction possible for this entity - setting to NIL + to_discard = self.cfg.get("labels_discard", []) + if to_discard and ent.label_ in to_discard: + # ignoring this entity - setting to NIL final_kb_ids.append(self.NIL) final_tensors.append(sentence_encoding) - elif len(candidates) == 1: - # shortcut for efficiency reasons: take the 1 candidate - - # TODO: thresholding - final_kb_ids.append(candidates[0].entity_) - final_tensors.append(sentence_encoding) - else: - random.shuffle(candidates) + candidates = self.kb.get_candidates(ent.text) + if not candidates: + # no prediction possible for this entity - setting to NIL + final_kb_ids.append(self.NIL) + final_tensors.append(sentence_encoding) - # this will set all prior probabilities to 0 if they should be excluded from the model - prior_probs = xp.asarray([c.prior_prob for c in candidates]) - if not self.cfg.get("incl_prior", True): - prior_probs = xp.asarray([0.0 for c in candidates]) - scores = prior_probs + elif len(candidates) == 1: + # shortcut for efficiency reasons: take the 1 candidate - # add in similarity from the context - if self.cfg.get("incl_context", True): - entity_encodings = xp.asarray([c.entity_vector for c in candidates]) - entity_norm = xp.linalg.norm(entity_encodings, axis=1) + # TODO: thresholding + final_kb_ids.append(candidates[0].entity_) + final_tensors.append(sentence_encoding) - if len(entity_encodings) != len(prior_probs): - raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) + else: + random.shuffle(candidates) - # cosine similarity - sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm) - if sims.shape != prior_probs.shape: - raise ValueError(Errors.E161) - scores = prior_probs + sims - (prior_probs*sims) + # this will set all prior probabilities to 0 if they should be excluded from the model + prior_probs = xp.asarray([c.prior_prob for c in candidates]) + if not self.cfg.get("incl_prior", True): + prior_probs = xp.asarray([0.0 for c in candidates]) + scores = prior_probs - # TODO: thresholding - best_index = scores.argmax().item() - best_candidate = candidates[best_index] - final_kb_ids.append(best_candidate.entity_) - final_tensors.append(sentence_encoding) + # add in similarity from the context + if self.cfg.get("incl_context", True): + entity_encodings = xp.asarray([c.entity_vector for c in candidates]) + entity_norm = xp.linalg.norm(entity_encodings, axis=1) + + if len(entity_encodings) != len(prior_probs): + raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) + + # cosine similarity + sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) + scores = prior_probs + sims - (prior_probs*sims) + + # TODO: thresholding + best_index = scores.argmax().item() + best_candidate = candidates[best_index] + final_kb_ids.append(best_candidate.entity_) + final_tensors.append(sentence_encoding) if not (len(final_tensors) == len(final_kb_ids) == entity_count): raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))