From 41fb5204ba47b71b37d012d06e8b039983fa0ef9 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 19 Jul 2019 14:47:36 +0200 Subject: [PATCH] output tensors as part of predict --- spacy/pipeline/pipes.pyx | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index a8746c73d..5704878b8 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1212,15 +1212,15 @@ class EntityLinker(Pipe): return loss, d_scores def __call__(self, doc): - kb_ids = self.predict([doc]) - self.set_annotations([doc], kb_ids) + kb_ids, tensors = self.predict([doc]) + self.set_annotations([doc], kb_ids, tensors=tensors) return doc def pipe(self, stream, batch_size=128, n_threads=-1): for docs in util.minibatch(stream, size=batch_size): docs = list(docs) - kb_ids = self.predict(docs) - self.set_annotations(docs, kb_ids) + kb_ids, tensors = self.predict(docs) + self.set_annotations(docs, kb_ids, tensors=tensors) yield from docs def predict(self, docs): @@ -1230,6 +1230,7 @@ class EntityLinker(Pipe): entity_count = 0 final_kb_ids = [] + final_tensors = [] if not docs: return final_kb_ids @@ -1244,6 +1245,7 @@ class EntityLinker(Pipe): for i, doc in enumerate(docs): if len(doc) > 0: + # currently, the context is the same for each entity in a sentence (should be refined) context_encoding = context_encodings[i] for ent in doc.ents: entity_count += 1 @@ -1254,6 +1256,7 @@ class EntityLinker(Pipe): candidates = self.kb.get_candidates(ent.text) if not candidates: final_kb_ids.append(self.NIL) # no prediction possible for this entity + final_tensors.append(context_encoding) else: random.shuffle(candidates) @@ -1274,12 +1277,16 @@ class EntityLinker(Pipe): best_index = scores.argmax() best_candidate = candidates[best_index] final_kb_ids.append(best_candidate.entity_) + final_tensors.append(context_encoding) - assert len(final_kb_ids) == entity_count + assert len(final_tensors) == len(final_kb_ids) == entity_count - return final_kb_ids + return final_kb_ids, final_tensors def set_annotations(self, docs, kb_ids, tensors=None): + count_ents = len([ent for doc in docs for ent in doc.ents]) + assert count_ents == len(kb_ids) + i=0 for doc in docs: for ent in doc.ents: