add line that got removed from EntityLinker

This commit is contained in:
svlandeg 2020-06-20 23:14:45 +02:00
parent 12dc8ab208
commit 2f6062a8a4
1 changed file with 53 additions and 52 deletions

View File

@ -1302,71 +1302,72 @@ class EntityLinker(Pipe):
# Looping through each sentence and each entity # Looping through each sentence and each entity
# This may go wrong if there are entities across sentences - which shouldn't happen normally. # This may go wrong if there are entities across sentences - which shouldn't happen normally.
for sent_index, sent in enumerate(sentences): for sent_index, sent in enumerate(sentences):
# get n_neighbour sentences, clipped to the length of the document if sent.ents:
start_sentence = max(0, sent_index - self.n_sents) # get n_neighbour sentences, clipped to the length of the document
end_sentence = min(len(sentences) -1, sent_index + self.n_sents) start_sentence = max(0, sent_index - self.n_sents)
end_sentence = min(len(sentences) -1, sent_index + self.n_sents)
start_token = sentences[start_sentence].start start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc() sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0] sentence_encoding = self.model.predict([sent_doc])[0]
xp = get_array_module(sentence_encoding) xp = get_array_module(sentence_encoding)
sentence_encoding_t = sentence_encoding.T sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t) sentence_norm = xp.linalg.norm(sentence_encoding_t)
for ent in sent.ents: for ent in sent.ents:
entity_count += 1 entity_count += 1
to_discard = self.cfg.get("labels_discard", []) to_discard = self.cfg.get("labels_discard", [])
if to_discard and ent.label_ in to_discard: if to_discard and ent.label_ in to_discard:
# ignoring this entity - setting to NIL # ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding)
else:
candidates = self.kb.get_candidates(ent.text)
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding) final_tensors.append(sentence_encoding)
elif len(candidates) == 1:
# shortcut for efficiency reasons: take the 1 candidate
# TODO: thresholding
final_kb_ids.append(candidates[0].entity_)
final_tensors.append(sentence_encoding)
else: else:
random.shuffle(candidates) candidates = self.kb.get_candidates(ent.text)
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding)
# this will set all prior probabilities to 0 if they should be excluded from the model elif len(candidates) == 1:
prior_probs = xp.asarray([c.prior_prob for c in candidates]) # shortcut for efficiency reasons: take the 1 candidate
if not self.cfg.get("incl_prior", True):
prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs
# add in similarity from the context # TODO: thresholding
if self.cfg.get("incl_context", True): final_kb_ids.append(candidates[0].entity_)
entity_encodings = xp.asarray([c.entity_vector for c in candidates]) final_tensors.append(sentence_encoding)
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
if len(entity_encodings) != len(prior_probs): else:
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) random.shuffle(candidates)
# cosine similarity # this will set all prior probabilities to 0 if they should be excluded from the model
sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm) prior_probs = xp.asarray([c.prior_prob for c in candidates])
if sims.shape != prior_probs.shape: if not self.cfg.get("incl_prior", True):
raise ValueError(Errors.E161) prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs + sims - (prior_probs*sims) scores = prior_probs
# TODO: thresholding # add in similarity from the context
best_index = scores.argmax().item() if self.cfg.get("incl_context", True):
best_candidate = candidates[best_index] entity_encodings = xp.asarray([c.entity_vector for c in candidates])
final_kb_ids.append(best_candidate.entity_) entity_norm = xp.linalg.norm(entity_encodings, axis=1)
final_tensors.append(sentence_encoding)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
# cosine similarity
sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding
best_index = scores.argmax().item()
best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_)
final_tensors.append(sentence_encoding)
if not (len(final_tensors) == len(final_kb_ids) == entity_count): if not (len(final_tensors) == len(final_kb_ids) == entity_count):
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length")) raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))