diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index b7aca26b8..ffd7c8da6 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -427,7 +427,7 @@ cdef class Parser: cuda_stream = get_cuda_stream() - states, golds, max_length = self._init_gold_batch(docs, golds) + states, golds, max_steps = self._init_gold_batch(docs, golds) state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, 0.0) todo = [(s, g) for (s, g) in zip(states, golds) @@ -438,6 +438,7 @@ cdef class Parser: backprops = [] d_tokvecs = state2vec.ops.allocate(tokvecs.shape) cdef float loss = 0. + n_steps = 0 while todo: states, golds = zip(*todo) @@ -467,7 +468,8 @@ cdef class Parser: todo = [st for st in todo if not st[0].is_final()] if losses is not None: losses[self.name] += (d_scores**2).sum() - if len(backprops) >= (max_length * 2): + n_steps += 1 + if n_steps >= max_steps: break self._make_updates(d_tokvecs, backprops, sgd, cuda_stream) @@ -482,7 +484,8 @@ cdef class Parser: StateClass state Transition action whole_states = self.moves.init_batch(whole_docs) - max_length = max(5, min(20, min([len(doc) for doc in whole_docs]))) + max_length = max(5, min(50, min([len(doc) for doc in whole_docs]))) + max_moves = 0 states = [] golds = [] for doc, state, gold in zip(whole_docs, whole_states, whole_golds): @@ -493,16 +496,20 @@ cdef class Parser: start = 0 while start < len(doc): state = state.copy() + n_moves = 0 while state.B(0) < start and not state.is_final(): action = self.moves.c[oracle_actions.pop(0)] action.do(state.c, action.label) + n_moves += 1 has_gold = self.moves.has_gold(gold, start=start, end=start+max_length) if not state.is_final() and has_gold: states.append(state) golds.append(gold) + max_moves = max(max_moves, n_moves) start += min(max_length, len(doc)-start) - return states, golds, max_length + max_moves = max(max_moves, len(oracle_actions)) + return states, golds, max_moves def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None): # Tells CUDA to block, so our async copies complete.