mirror of https://github.com/explosion/spaCy.git
fix entity linker (cf PR #5548)
This commit is contained in:
parent
dc069e90b3
commit
c9242e9bf4
|
@ -1219,13 +1219,11 @@ class EntityLinker(Pipe):
|
|||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
sentence_docs.append(sent_doc)
|
||||
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
|
||||
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
|
||||
bp_context(d_scores, sgd=sgd)
|
||||
set_dropout_rate(self.model, drop)
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
||||
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
|
||||
bp_context(d_scores)
|
||||
|
||||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
|
||||
|
@ -1306,22 +1304,28 @@ class EntityLinker(Pipe):
|
|||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
||||
|
||||
for i, doc in enumerate(docs):
|
||||
sentences = [s for s in doc.sents]
|
||||
|
||||
if len(doc) > 0:
|
||||
# Looping through each sentence and each entity
|
||||
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
|
||||
for sent in doc.sents:
|
||||
sent_doc = sent.as_doc()
|
||||
for sent_index, sent in enumerate(sentences):
|
||||
# get n_neightbour sentences, clipped to the length of the document
|
||||
start_sentence = max(0, sent_index - self.n_sents)
|
||||
end_sentence = min(len(sentences) -1, sent_index + self.n_sents)
|
||||
|
||||
start_token = sentences[start_sentence].start
|
||||
end_token = sentences[end_sentence].end
|
||||
|
||||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||
xp = get_array_module(sentence_encoding)
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
|
||||
for ent in sent_doc.ents:
|
||||
for ent in sent.ents:
|
||||
entity_count += 1
|
||||
|
||||
to_discard = self.cfg.get("labels_discard", [])
|
||||
|
@ -1337,21 +1341,11 @@ class EntityLinker(Pipe):
|
|||
final_kb_ids.append(self.NIL)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
sent_doc = doc[sent.start:sent.end].as_doc()
|
||||
elif len(candidates) == 1:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model([sent_doc])[0]
|
||||
xp = get_array_module(sentence_encoding)
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
|
||||
for ent in sent.ents:
|
||||
entity_count += 1
|
||||
|
||||
to_discard = self.cfg.get("labels_discard", [])
|
||||
if to_discard and ent.label_ in to_discard:
|
||||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
# TODO: thresholding
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue