diff --git a/spacy/pipeline.pyx b/spacy/pipeline.pyx index 283e4b106..e55710dee 100644 --- a/spacy/pipeline.pyx +++ b/spacy/pipeline.pyx @@ -352,10 +352,12 @@ class Tagger(Pipe): def predict(self, docs): tokvecs = self.model.tok2vec(docs) scores = self.model.softmax(tokvecs) - guesses = scores.argmax(axis=1) - if not isinstance(guesses, numpy.ndarray): - guesses = guesses.get() - guesses = self.model.ops.unflatten(guesses, [len(d) for d in docs]) + guesses = [] + for doc_scores in scores: + doc_guesses = doc_scores.argmax(axis=1) + if not isinstance(doc_guesses, numpy.ndarray): + doc_guesses = doc_guesses.get() + guesses.append(doc_guesses) return guesses, tokvecs def set_annotations(self, docs, batch_tag_ids, tensors=None):