mirror of https://github.com/explosion/spaCy.git
output tensors as part of predict
This commit is contained in:
parent
21176517a7
commit
41fb5204ba
|
@ -1212,15 +1212,15 @@ class EntityLinker(Pipe):
|
||||||
return loss, d_scores
|
return loss, d_scores
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
kb_ids = self.predict([doc])
|
kb_ids, tensors = self.predict([doc])
|
||||||
self.set_annotations([doc], kb_ids)
|
self.set_annotations([doc], kb_ids, tensors=tensors)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
docs = list(docs)
|
docs = list(docs)
|
||||||
kb_ids = self.predict(docs)
|
kb_ids, tensors = self.predict(docs)
|
||||||
self.set_annotations(docs, kb_ids)
|
self.set_annotations(docs, kb_ids, tensors=tensors)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
|
@ -1230,6 +1230,7 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
entity_count = 0
|
entity_count = 0
|
||||||
final_kb_ids = []
|
final_kb_ids = []
|
||||||
|
final_tensors = []
|
||||||
|
|
||||||
if not docs:
|
if not docs:
|
||||||
return final_kb_ids
|
return final_kb_ids
|
||||||
|
@ -1244,6 +1245,7 @@ class EntityLinker(Pipe):
|
||||||
|
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
if len(doc) > 0:
|
if len(doc) > 0:
|
||||||
|
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||||
context_encoding = context_encodings[i]
|
context_encoding = context_encodings[i]
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
entity_count += 1
|
entity_count += 1
|
||||||
|
@ -1254,6 +1256,7 @@ class EntityLinker(Pipe):
|
||||||
candidates = self.kb.get_candidates(ent.text)
|
candidates = self.kb.get_candidates(ent.text)
|
||||||
if not candidates:
|
if not candidates:
|
||||||
final_kb_ids.append(self.NIL) # no prediction possible for this entity
|
final_kb_ids.append(self.NIL) # no prediction possible for this entity
|
||||||
|
final_tensors.append(context_encoding)
|
||||||
else:
|
else:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
|
|
||||||
|
@ -1274,12 +1277,16 @@ class EntityLinker(Pipe):
|
||||||
best_index = scores.argmax()
|
best_index = scores.argmax()
|
||||||
best_candidate = candidates[best_index]
|
best_candidate = candidates[best_index]
|
||||||
final_kb_ids.append(best_candidate.entity_)
|
final_kb_ids.append(best_candidate.entity_)
|
||||||
|
final_tensors.append(context_encoding)
|
||||||
|
|
||||||
assert len(final_kb_ids) == entity_count
|
assert len(final_tensors) == len(final_kb_ids) == entity_count
|
||||||
|
|
||||||
return final_kb_ids
|
return final_kb_ids, final_tensors
|
||||||
|
|
||||||
def set_annotations(self, docs, kb_ids, tensors=None):
|
def set_annotations(self, docs, kb_ids, tensors=None):
|
||||||
|
count_ents = len([ent for doc in docs for ent in doc.ents])
|
||||||
|
assert count_ents == len(kb_ids)
|
||||||
|
|
||||||
i=0
|
i=0
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
for ent in doc.ents:
|
for ent in doc.ents:
|
||||||
|
|
Loading…
Reference in New Issue