From bfffdeabb2ad16b65a1d5c2b0c0f088d47e7f7cc Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 6 Aug 2017 14:10:48 +0200 Subject: [PATCH] Fix parser batch-size bug introduced during cleanup --- spacy/syntax/nn_parser.pyx | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 66787c22a..4be31b4de 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -339,12 +339,10 @@ cdef class Parser: The number of threads with which to work on the buffer in parallel. Yields (Doc): Documents, in order. """ - cdef StateClass parse_state cdef Doc doc - queue = [] for docs in cytoolz.partition_all(batch_size, docs): docs = list(docs) - tokvecs = [d.tensor for d in docs] + tokvecs = [doc.tensor for doc in docs] if beam_width == 1: parse_states = self.parse_batch(docs, tokvecs) else: @@ -364,6 +362,8 @@ cdef class Parser: int nr_class, nr_feat, nr_piece, nr_dim, nr_state if isinstance(docs, Doc): docs = [docs] + if isinstance(tokvecses, np.ndarray): + tokvecses = [tokvecses] tokvecs = self.model[0].ops.flatten(tokvecses) @@ -395,14 +395,14 @@ cdef class Parser: st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat) self.moves.set_valid(&c_is_valid[i*nr_class], st) vectors = state2vec(token_ids[:next_step.size()]) - scores = vec2scores(vectors) - c_scores = scores.data - for i in range(next_step.size()): - st = next_step[i] - guess = arg_max_if_valid( - &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) - action = self.moves.c[guess] - action.do(st, action.label) + scores = vec2scores(vectors) + c_scores = scores.data + for i in range(next_step.size()): + st = next_step[i] + guess = arg_max_if_valid( + &c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class) + action = self.moves.c[guess] + action.do(st, action.label) this_step, next_step = next_step, this_step next_step.clear() for st in this_step: