Fix tensor extending in tagger

This commit is contained in:
Matthew Honnibal 2017-11-03 13:29:36 +01:00
parent bd2cbdfa85
commit 6681058abd
1 changed files with 6 additions and 4 deletions

View File

@ -352,10 +352,12 @@ class Tagger(Pipe):
def predict(self, docs): def predict(self, docs):
tokvecs = self.model.tok2vec(docs) tokvecs = self.model.tok2vec(docs)
scores = self.model.softmax(tokvecs) scores = self.model.softmax(tokvecs)
guesses = scores.argmax(axis=1) guesses = []
if not isinstance(guesses, numpy.ndarray): for doc_scores in scores:
guesses = guesses.get() doc_guesses = doc_scores.argmax(axis=1)
guesses = self.model.ops.unflatten(guesses, [len(d) for d in docs]) if not isinstance(doc_guesses, numpy.ndarray):
doc_guesses = doc_guesses.get()
guesses.append(doc_guesses)
return guesses, tokvecs return guesses, tokvecs
def set_annotations(self, docs, batch_tag_ids, tensors=None): def set_annotations(self, docs, batch_tag_ids, tensors=None):