diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 7c800eed8..536c2a8a5 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1219,13 +1219,11 @@ class EntityLinker(Pipe):
                 sent_doc = doc[start_token:end_token].as_doc()
                 sentence_docs.append(sent_doc)
 
-        sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
-        loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
-        bp_context(d_scores, sgd=sgd)
         set_dropout_rate(self.model, drop)
         sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
         loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
         bp_context(d_scores)
+        if sgd is not None:
             self.model.finish_update(sgd)
@@ -1306,22 +1304,28 @@ class EntityLinker(Pipe):
         if isinstance(docs, Doc):
             docs = [docs]
-
         for i, doc in enumerate(docs):
             sentences = [s for s in doc.sents]
             if len(doc) > 0:
                 # Looping through each sentence and each entity
                 # This may go wrong if there are entities across sentences - which shouldn't happen normally.
-                for sent in doc.sents:
-                    sent_doc = sent.as_doc()
+                for sent_index, sent in enumerate(sentences):
+                    # get n_sents neighbouring sentences, clipped to the length of the document
+                    start_sentence = max(0, sent_index - self.n_sents)
+                    end_sentence = min(len(sentences) - 1, sent_index + self.n_sents)
+
+                    start_token = sentences[start_sentence].start
+                    end_token = sentences[end_sentence].end
+
+                    sent_doc = doc[start_token:end_token].as_doc()
                     # currently, the context is the same for each entity in a sentence (should be refined)
                     sentence_encoding = self.model.predict([sent_doc])[0]
                     xp = get_array_module(sentence_encoding)
                     sentence_encoding_t = sentence_encoding.T
                     sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                    for ent in sent_doc.ents:
+                    for ent in sent.ents:
                         entity_count += 1
 
                         to_discard = self.cfg.get("labels_discard", [])
@@ -1337,21 +1341,11 @@ class EntityLinker(Pipe):
                                 final_kb_ids.append(self.NIL)
                                 final_tensors.append(sentence_encoding)
-                sent_doc = doc[sent.start:sent.end].as_doc()
+                            elif len(candidates) == 1:
+                                # shortcut for efficiency reasons: take the 1 candidate
 
-                # currently, the context is the same for each entity in a sentence (should be refined)
-                sentence_encoding = self.model([sent_doc])[0]
-                xp = get_array_module(sentence_encoding)
-                sentence_encoding_t = sentence_encoding.T
-                sentence_norm = xp.linalg.norm(sentence_encoding_t)
-
-                for ent in sent.ents:
-                    entity_count += 1
-
-                    to_discard = self.cfg.get("labels_discard", [])
-                    if to_discard and ent.label_ in to_discard:
-                        # ignoring this entity - setting to NIL
-                        final_kb_ids.append(self.NIL)
+                                # TODO: thresholding
+                                final_kb_ids.append(candidates[0].entity_)
                                 final_tensors.append(sentence_encoding)
                             else:
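
Note on the first hunk: the removed lines are a leftover duplicate of the old Thinc v7-style calls, where dropout and the optimizer were passed directly into begin_update() and the backprop callback. The surviving lines follow the Thinc v8 pattern, and the added guard ensures finish_update() only runs when an optimizer was actually passed in. Below is a minimal standalone sketch of that pattern; it uses a toy thinc.api.Linear model in place of the EntityLinker's real network, and the shapes and dummy gradient are made up for illustration:

    import numpy
    from thinc.api import SGD, Linear, set_dropout_rate

    # Toy stand-in for the EntityLinker's context-encoding model.
    model = Linear(nO=2, nI=4)
    model.initialize(X=numpy.zeros((3, 4), dtype="f"), Y=numpy.zeros((3, 2), dtype="f"))
    optimizer = SGD(0.001)

    X = numpy.random.rand(3, 4).astype("f")
    d_scores = numpy.ones((3, 2), dtype="f")  # dummy gradient of the loss w.r.t. the scores

    set_dropout_rate(model, 0.2)                # dropout is set on the model up front...
    scores, bp_context = model.begin_update(X)  # ...not passed to begin_update()
    bp_context(d_scores)                        # backprop: accumulate gradients on the model
    if optimizer is not None:                   # same guard as in the patch
        model.finish_update(optimizer)          # apply the optimizer once per step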
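
The heart of the second hunk is the sentence-window arithmetic: instead of encoding each sentence in isolation, predict() now encodes the sentence together with up to n_sents neighbouring sentences on each side, clipped to the document boundaries, and reuses that single encoding for every entity in the sentence. (The third hunk likewise deletes a duplicated leftover block and restores the single-candidate shortcut.) Here is a self-contained sketch of just the window arithmetic; the Sentence dataclass is a hypothetical stand-in for a spaCy Span, keeping only the .start/.end token offsets the patch relies on:

    from dataclasses import dataclass
    from typing import List, Tuple

    @dataclass
    class Sentence:
        start: int  # index of the sentence's first token in the document
        end: int    # index one past the sentence's last token

    def context_window(sentences: List[Sentence], sent_index: int, n_sents: int) -> Tuple[int, int]:
        # Same arithmetic as the patch: clip the window of n_sents sentences
        # on each side to the first and last sentence of the document.
        start_sentence = max(0, sent_index - n_sents)
        end_sentence = min(len(sentences) - 1, sent_index + n_sents)
        return sentences[start_sentence].start, sentences[end_sentence].end

    # Four sentences of five tokens each, window of one sentence per side.
    sents = [Sentence(i * 5, i * 5 + 5) for i in range(4)]
    assert context_window(sents, 0, 1) == (0, 10)   # clipped at the document start
    assert context_window(sents, 2, 1) == (5, 20)   # one neighbour on each side
    assert context_window(sents, 3, 1) == (10, 20)  # clipped at the document end

Encoding once per window and sharing the result across a sentence's entities keeps the number of forward passes proportional to the number of sentences rather than the number of entities.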