output tensors as part of predict

This commit is contained in:
svlandeg 2019-07-19 14:47:36 +02:00
parent 21176517a7
commit 41fb5204ba
1 changed files with 13 additions and 6 deletions

View File

@ -1212,15 +1212,15 @@ class EntityLinker(Pipe):
return loss, d_scores
def __call__(self, doc):
kb_ids = self.predict([doc])
self.set_annotations([doc], kb_ids)
kb_ids, tensors = self.predict([doc])
self.set_annotations([doc], kb_ids, tensors=tensors)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
kb_ids = self.predict(docs)
self.set_annotations(docs, kb_ids)
kb_ids, tensors = self.predict(docs)
self.set_annotations(docs, kb_ids, tensors=tensors)
yield from docs
def predict(self, docs):
@ -1230,6 +1230,7 @@ class EntityLinker(Pipe):
entity_count = 0
final_kb_ids = []
final_tensors = []
if not docs:
return final_kb_ids
@ -1244,6 +1245,7 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs):
if len(doc) > 0:
# currently, the context is the same for each entity in a sentence (should be refined)
context_encoding = context_encodings[i]
for ent in doc.ents:
entity_count += 1
@ -1254,6 +1256,7 @@ class EntityLinker(Pipe):
candidates = self.kb.get_candidates(ent.text)
if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity
final_tensors.append(context_encoding)
else:
random.shuffle(candidates)
@ -1274,12 +1277,16 @@ class EntityLinker(Pipe):
best_index = scores.argmax()
best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_)
final_tensors.append(context_encoding)
assert len(final_kb_ids) == entity_count
assert len(final_tensors) == len(final_kb_ids) == entity_count
return final_kb_ids
return final_kb_ids, final_tensors
def set_annotations(self, docs, kb_ids, tensors=None):
count_ents = len([ent for doc in docs for ent in doc.ents])
assert count_ents == len(kb_ids)
i=0
for doc in docs:
for ent in doc.ents: