diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index fba4b485f..287feecce 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -23,7 +23,7 @@ def build_nel_encoder( ((tok2vec >> list2ragged()) & build_span_maker()) >> extract_spans() >> reduce_mean() - >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) + >> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0)) # type: ignore >> output_layer ) model.set_ref("output_layer", output_layer) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 12c3e382f..aa7985a9c 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -355,7 +355,7 @@ class EntityLinker(TrainablePipe): keep_ents.append(eidx) eidx += 1 - entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") + entity_encodings = self.model.ops.asarray2f(entity_encodings, dtype="float32") selected_encodings = sentence_encodings[keep_ents] # if there are no matches, short circuit @@ -368,13 +368,12 @@ class EntityLinker(TrainablePipe): method="get_loss", msg="gold entities do not match up" ) raise RuntimeError(err) - # TODO: fix typing issue here - gradients = self.distance.get_grad(selected_encodings, entity_encodings) # type: ignore + gradients = self.distance.get_grad(selected_encodings, entity_encodings) # to match the input size, we need to give a zero gradient for items not in the kb out = self.model.ops.alloc2f(*sentence_encodings.shape) out[keep_ents] = gradients - loss = self.distance.get_loss(selected_encodings, entity_encodings) # type: ignore + loss = self.distance.get_loss(selected_encodings, entity_encodings) loss = loss / len(entity_encodings) return float(loss), out @@ -391,18 +390,21 @@ class EntityLinker(TrainablePipe): self.validate_kb() entity_count = 0 final_kb_ids: List[str] = [] + xp = self.model.ops.xp if not docs: return final_kb_ids if isinstance(docs, Doc): docs = [docs] for i, doc in enumerate(docs): + if len(doc) == 0: + continue sentences = [s for s in doc.sents] - if len(doc) > 0: - # Looping through each entity (TODO: rewrite) - for ent in doc.ents: - sent = ent.sent - sent_index = sentences.index(sent) - assert sent_index >= 0 + # Looping through each entity (TODO: rewrite) + for ent in doc.ents: + sent_index = sentences.index(ent.sent) + assert sent_index >= 0 + + if self.incl_context: # get n_neighbour sentences, clipped to the length of the document start_sentence = max(0, sent_index - self.n_sents) end_sentence = min(len(sentences) - 1, sent_index + self.n_sents) @@ -410,55 +412,53 @@ class EntityLinker(TrainablePipe): end_token = sentences[end_sentence].end sent_doc = doc[start_token:end_token].as_doc() # currently, the context is the same for each entity in a sentence (should be refined) - xp = self.model.ops.xp - if self.incl_context: - sentence_encoding = self.model.predict([sent_doc])[0] - sentence_encoding_t = sentence_encoding.T - sentence_norm = xp.linalg.norm(sentence_encoding_t) - entity_count += 1 - if ent.label_ in self.labels_discard: - # ignoring this entity - setting to NIL + sentence_encoding = self.model.predict([sent_doc])[0] + sentence_encoding_t = sentence_encoding.T + sentence_norm = xp.linalg.norm(sentence_encoding_t) + entity_count += 1 + if ent.label_ in self.labels_discard: + # ignoring this entity - setting to NIL + final_kb_ids.append(self.NIL) + else: + candidates = list(self.get_candidates(self.kb, ent)) + if not candidates: + # no prediction possible for this entity - setting to NIL final_kb_ids.append(self.NIL) + elif len(candidates) == 1: + # shortcut for efficiency reasons: take the 1 candidate + # TODO: thresholding + final_kb_ids.append(candidates[0].entity_) else: - candidates = list(self.get_candidates(self.kb, ent)) - if not candidates: - # no prediction possible for this entity - setting to NIL - final_kb_ids.append(self.NIL) - elif len(candidates) == 1: - # shortcut for efficiency reasons: take the 1 candidate - # TODO: thresholding - final_kb_ids.append(candidates[0].entity_) - else: - random.shuffle(candidates) - # set all prior probabilities to 0 if incl_prior=False - prior_probs = xp.asarray([c.prior_prob for c in candidates]) - if not self.incl_prior: - prior_probs = xp.asarray([0.0 for _ in candidates]) - scores = prior_probs - # add in similarity from the context - if self.incl_context: - entity_encodings = xp.asarray( - [c.entity_vector for c in candidates] - ) - entity_norm = xp.linalg.norm(entity_encodings, axis=1) - if len(entity_encodings) != len(prior_probs): - raise RuntimeError( - Errors.E147.format( - method="predict", - msg="vectors not of equal length", - ) + random.shuffle(candidates) + # set all prior probabilities to 0 if incl_prior=False + prior_probs = xp.asarray([c.prior_prob for c in candidates]) + if not self.incl_prior: + prior_probs = xp.asarray([0.0 for _ in candidates]) + scores = prior_probs + # add in similarity from the context + if self.incl_context: + entity_encodings = xp.asarray( + [c.entity_vector for c in candidates] + ) + entity_norm = xp.linalg.norm(entity_encodings, axis=1) + if len(entity_encodings) != len(prior_probs): + raise RuntimeError( + Errors.E147.format( + method="predict", + msg="vectors not of equal length", ) - # cosine similarity - sims = xp.dot(entity_encodings, sentence_encoding_t) / ( - sentence_norm * entity_norm ) - if sims.shape != prior_probs.shape: - raise ValueError(Errors.E161) - scores = prior_probs + sims - (prior_probs * sims) - # TODO: thresholding - best_index = scores.argmax().item() - best_candidate = candidates[best_index] - final_kb_ids.append(best_candidate.entity_) + # cosine similarity + sims = xp.dot(entity_encodings, sentence_encoding_t) / ( + sentence_norm * entity_norm + ) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) + scores = prior_probs + sims - (prior_probs * sims) + # TODO: thresholding + best_index = scores.argmax().item() + best_candidate = candidates[best_index] + final_kb_ids.append(best_candidate.entity_) if not (len(final_kb_ids) == entity_count): err = Errors.E147.format( method="predict", msg="result variables not of equal length"