mirror of https://github.com/explosion/spaCy.git
fix entity linker (cf PR #5548)
This commit is contained in:
parent
dc069e90b3
commit
c9242e9bf4
|
@ -1219,13 +1219,11 @@ class EntityLinker(Pipe):
|
||||||
sent_doc = doc[start_token:end_token].as_doc()
|
sent_doc = doc[start_token:end_token].as_doc()
|
||||||
sentence_docs.append(sent_doc)
|
sentence_docs.append(sent_doc)
|
||||||
|
|
||||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
|
|
||||||
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
|
|
||||||
bp_context(d_scores, sgd=sgd)
|
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
||||||
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
|
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds)
|
||||||
bp_context(d_scores)
|
bp_context(d_scores)
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.model.finish_update(sgd)
|
self.model.finish_update(sgd)
|
||||||
|
|
||||||
|
@ -1306,22 +1304,28 @@ class EntityLinker(Pipe):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
|
|
||||||
|
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
sentences = [s for s in doc.sents]
|
sentences = [s for s in doc.sents]
|
||||||
|
|
||||||
if len(doc) > 0:
|
if len(doc) > 0:
|
||||||
# Looping through each sentence and each entity
|
# Looping through each sentence and each entity
|
||||||
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
|
# This may go wrong if there are entities across sentences - which shouldn't happen normally.
|
||||||
for sent in doc.sents:
|
for sent_index, sent in enumerate(sentences):
|
||||||
sent_doc = sent.as_doc()
|
# get up to n_sents neighbouring sentences on each side, clipped to the document boundaries
|
||||||
|
start_sentence = max(0, sent_index - self.n_sents)
|
||||||
|
end_sentence = min(len(sentences) -1, sent_index + self.n_sents)
|
||||||
|
|
||||||
|
start_token = sentences[start_sentence].start
|
||||||
|
end_token = sentences[end_sentence].end
|
||||||
|
|
||||||
|
sent_doc = doc[start_token:end_token].as_doc()
|
||||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||||
xp = get_array_module(sentence_encoding)
|
xp = get_array_module(sentence_encoding)
|
||||||
sentence_encoding_t = sentence_encoding.T
|
sentence_encoding_t = sentence_encoding.T
|
||||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||||
|
|
||||||
for ent in sent_doc.ents:
|
for ent in sent.ents:
|
||||||
entity_count += 1
|
entity_count += 1
|
||||||
|
|
||||||
to_discard = self.cfg.get("labels_discard", [])
|
to_discard = self.cfg.get("labels_discard", [])
|
||||||
|
@ -1337,21 +1341,11 @@ class EntityLinker(Pipe):
|
||||||
final_kb_ids.append(self.NIL)
|
final_kb_ids.append(self.NIL)
|
||||||
final_tensors.append(sentence_encoding)
|
final_tensors.append(sentence_encoding)
|
||||||
|
|
||||||
sent_doc = doc[sent.start:sent.end].as_doc()
|
elif len(candidates) == 1:
|
||||||
|
# shortcut for efficiency reasons: take the 1 candidate
|
||||||
|
|
||||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
# TODO: thresholding
|
||||||
sentence_encoding = self.model([sent_doc])[0]
|
final_kb_ids.append(candidates[0].entity_)
|
||||||
xp = get_array_module(sentence_encoding)
|
|
||||||
sentence_encoding_t = sentence_encoding.T
|
|
||||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
|
||||||
|
|
||||||
for ent in sent.ents:
|
|
||||||
entity_count += 1
|
|
||||||
|
|
||||||
to_discard = self.cfg.get("labels_discard", [])
|
|
||||||
if to_discard and ent.label_ in to_discard:
|
|
||||||
# ignoring this entity - setting to NIL
|
|
||||||
final_kb_ids.append(self.NIL)
|
|
||||||
final_tensors.append(sentence_encoding)
|
final_tensors.append(sentence_encoding)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue