mirror of https://github.com/explosion/spaCy.git
Fix parser batch-size bug introduced during cleanup
This commit is contained in:
parent
0eec7c9e9b
commit
bfffdeabb2
|
@ -339,12 +339,10 @@ cdef class Parser:
|
|||
The number of threads with which to work on the buffer in parallel.
|
||||
Yields (Doc): Documents, in order.
|
||||
"""
|
||||
cdef StateClass parse_state
|
||||
cdef Doc doc
|
||||
queue = []
|
||||
for docs in cytoolz.partition_all(batch_size, docs):
|
||||
docs = list(docs)
|
||||
tokvecs = [d.tensor for d in docs]
|
||||
tokvecs = [doc.tensor for doc in docs]
|
||||
if beam_width == 1:
|
||||
parse_states = self.parse_batch(docs, tokvecs)
|
||||
else:
|
||||
|
@ -364,6 +362,8 @@ cdef class Parser:
|
|||
int nr_class, nr_feat, nr_piece, nr_dim, nr_state
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
if isinstance(tokvecses, np.ndarray):
|
||||
tokvecses = [tokvecses]
|
||||
|
||||
tokvecs = self.model[0].ops.flatten(tokvecses)
|
||||
|
||||
|
@ -395,14 +395,14 @@ cdef class Parser:
|
|||
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
|
||||
self.moves.set_valid(&c_is_valid[i*nr_class], st)
|
||||
vectors = state2vec(token_ids[:next_step.size()])
|
||||
scores = vec2scores(vectors)
|
||||
c_scores = <float*>scores.data
|
||||
for i in range(next_step.size()):
|
||||
st = next_step[i]
|
||||
guess = arg_max_if_valid(
|
||||
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
|
||||
action = self.moves.c[guess]
|
||||
action.do(st, action.label)
|
||||
scores = vec2scores(vectors)
|
||||
c_scores = <float*>scores.data
|
||||
for i in range(next_step.size()):
|
||||
st = next_step[i]
|
||||
guess = arg_max_if_valid(
|
||||
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
|
||||
action = self.moves.c[guess]
|
||||
action.do(st, action.label)
|
||||
this_step, next_step = next_step, this_step
|
||||
next_step.clear()
|
||||
for st in this_step:
|
||||
|
|
Loading…
Reference in New Issue