output tensors as part of predict

This commit is contained in:
svlandeg 2019-07-19 14:47:36 +02:00
parent 21176517a7
commit 41fb5204ba
1 changed file with 13 additions and 6 deletions

View File

@@ -1212,15 +1212,15 @@ class EntityLinker(Pipe):
return loss, d_scores return loss, d_scores
def __call__(self, doc): def __call__(self, doc):
kb_ids = self.predict([doc]) kb_ids, tensors = self.predict([doc])
self.set_annotations([doc], kb_ids) self.set_annotations([doc], kb_ids, tensors=tensors)
return doc return doc
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size): for docs in util.minibatch(stream, size=batch_size):
docs = list(docs) docs = list(docs)
kb_ids = self.predict(docs) kb_ids, tensors = self.predict(docs)
self.set_annotations(docs, kb_ids) self.set_annotations(docs, kb_ids, tensors=tensors)
yield from docs yield from docs
def predict(self, docs): def predict(self, docs):
@@ -1230,6 +1230,7 @@ class EntityLinker(Pipe):
entity_count = 0 entity_count = 0
final_kb_ids = [] final_kb_ids = []
final_tensors = []
if not docs: if not docs:
return final_kb_ids return final_kb_ids
@@ -1244,6 +1245,7 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if len(doc) > 0: if len(doc) > 0:
# currently, the context is the same for each entity in a sentence (should be refined)
context_encoding = context_encodings[i] context_encoding = context_encodings[i]
for ent in doc.ents: for ent in doc.ents:
entity_count += 1 entity_count += 1
@@ -1254,6 +1256,7 @@ class EntityLinker(Pipe):
candidates = self.kb.get_candidates(ent.text) candidates = self.kb.get_candidates(ent.text)
if not candidates: if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity final_kb_ids.append(self.NIL) # no prediction possible for this entity
final_tensors.append(context_encoding)
else: else:
random.shuffle(candidates) random.shuffle(candidates)
@@ -1274,12 +1277,16 @@ class EntityLinker(Pipe):
best_index = scores.argmax() best_index = scores.argmax()
best_candidate = candidates[best_index] best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_) final_kb_ids.append(best_candidate.entity_)
final_tensors.append(context_encoding)
assert len(final_kb_ids) == entity_count assert len(final_tensors) == len(final_kb_ids) == entity_count
return final_kb_ids return final_kb_ids, final_tensors
def set_annotations(self, docs, kb_ids, tensors=None): def set_annotations(self, docs, kb_ids, tensors=None):
count_ents = len([ent for doc in docs for ent in doc.ents])
assert count_ents == len(kb_ids)
i=0 i=0
for doc in docs: for doc in docs:
for ent in doc.ents: for ent in doc.ents: