mirror of https://github.com/explosion/spaCy.git
add line that got removed from EntityLinker
This commit is contained in:
parent 12dc8ab208
commit 2f6062a8a4
@@ -1302,71 +1302,72 @@ class EntityLinker(Pipe):
                 # Looping through each sentence and each entity
                 # This may go wrong if there are entities across sentences - which shouldn't happen normally.
                 for sent_index, sent in enumerate(sentences):
+                    if sent.ents:
                         # get n_neightbour sentences, clipped to the length of the document
                         start_sentence = max(0, sent_index - self.n_sents)
                         end_sentence = min(len(sentences) -1, sent_index + self.n_sents)

                         start_token = sentences[start_sentence].start
                         end_token = sentences[end_sentence].end

                         sent_doc = doc[start_token:end_token].as_doc()
                         # currently, the context is the same for each entity in a sentence (should be refined)
                         sentence_encoding = self.model.predict([sent_doc])[0]
                         xp = get_array_module(sentence_encoding)
                         sentence_encoding_t = sentence_encoding.T
                         sentence_norm = xp.linalg.norm(sentence_encoding_t)

                         for ent in sent.ents:
                             entity_count += 1

                             to_discard = self.cfg.get("labels_discard", [])
                             if to_discard and ent.label_ in to_discard:
                                 # ignoring this entity - setting to NIL
                                 final_kb_ids.append(self.NIL)
                                 final_tensors.append(sentence_encoding)

                             else:
                                 candidates = self.kb.get_candidates(ent.text)
                                 if not candidates:
                                     # no prediction possible for this entity - setting to NIL
                                     final_kb_ids.append(self.NIL)
                                     final_tensors.append(sentence_encoding)

                                 elif len(candidates) == 1:
                                     # shortcut for efficiency reasons: take the 1 candidate

                                     # TODO: thresholding
                                     final_kb_ids.append(candidates[0].entity_)
                                     final_tensors.append(sentence_encoding)

                                 else:
                                     random.shuffle(candidates)

                                     # this will set all prior probabilities to 0 if they should be excluded from the model
                                     prior_probs = xp.asarray([c.prior_prob for c in candidates])
                                     if not self.cfg.get("incl_prior", True):
                                         prior_probs = xp.asarray([0.0 for c in candidates])
                                     scores = prior_probs

                                     # add in similarity from the context
                                     if self.cfg.get("incl_context", True):
                                         entity_encodings = xp.asarray([c.entity_vector for c in candidates])
                                         entity_norm = xp.linalg.norm(entity_encodings, axis=1)

                                         if len(entity_encodings) != len(prior_probs):
                                             raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))

                                         # cosine similarity
                                         sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
                                         if sims.shape != prior_probs.shape:
                                             raise ValueError(Errors.E161)
                                         scores = prior_probs + sims - (prior_probs*sims)

                                     # TODO: thresholding
                                     best_index = scores.argmax().item()
                                     best_candidate = candidates[best_index]
                                     final_kb_ids.append(best_candidate.entity_)
                                     final_tensors.append(sentence_encoding)

         if not (len(final_tensors) == len(final_kb_ids) == entity_count):
             raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
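
The start_sentence / end_sentence clipping in the hunk keeps the window of n_sents neighbouring sentences inside the document. A standalone illustration of that clipping, using assumed values rather than a real Doc:

    n_sents = 2          # neighbouring sentences on each side (assumed value)
    n_sentences = 5      # total sentences in the document (assumed value)

    for sent_index in range(n_sentences):
        start_sentence = max(0, sent_index - n_sents)
        end_sentence = min(n_sentences - 1, sent_index + n_sents)
        print(sent_index, "->", (start_sentence, end_sentence))

    # 0 -> (0, 2)
    # 1 -> (0, 3)
    # 2 -> (0, 4)
    # 3 -> (1, 4)
    # 4 -> (2, 4)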
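
The candidate scoring in the hunk combines each candidate's prior probability from the KB with the cosine similarity between the candidate's entity vector and the sentence encoding, via prior + sim - prior * sim (a probabilistic OR). A minimal NumPy sketch of just that combination, with made-up vectors standing in for the model output and the KB entries (xp in the real code resolves to numpy or cupy via get_array_module):

    import numpy as np

    # made-up stand-ins: one sentence encoding and one KB vector per candidate
    sentence_encoding = np.array([0.2, 0.1, 0.4], dtype="float32")
    entity_encodings = np.array([[0.3, 0.0, 0.5],    # candidate A
                                 [0.0, 0.9, 0.1]],   # candidate B
                                dtype="float32")
    prior_probs = np.array([0.7, 0.2], dtype="float32")  # P(candidate | mention), from the KB

    # cosine similarity between the sentence context and each candidate vector
    sentence_norm = np.linalg.norm(sentence_encoding)
    entity_norm = np.linalg.norm(entity_encodings, axis=1)
    sims = entity_encodings.dot(sentence_encoding) / (sentence_norm * entity_norm)

    # combine prior and context evidence: p + s - p*s
    scores = prior_probs + sims - prior_probs * sims

    best_index = int(scores.argmax())   # index of the winning candidate
    print(sims, scores, best_index)

As the hunk shows, disabling incl_prior zeroes out the priors so the scores reduce to the cosine similarities, and disabling incl_context leaves the scores equal to the priors.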