diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index d53a1959a..ed89d4cc8 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -172,12 +172,32 @@ cdef class StepwiseState: return self.parser.moves.move_name(action.move, action.label) def transition(self, action_name): + moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3} if action_name == '_': action_name = self.predict() - action = self.parser.moves.lookup_transition(action_name) + if action_name == 'L' or action_name == 'R': + self.predict() + move = moves[action_name] + clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c, + self.eg.c.nr_class) + action = self.parser.moves.c[clas] + else: + action = self.parser.moves.lookup_transition(action_name) action.do(self.stcls, action.label) def finish(self): if self.stcls.is_final(): self.parser.moves.finalize_state(self.stcls) self.doc.set_parse(self.stcls._sent) + + +cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions, + int nr_class) except -1: + cdef weight_t score = 0 + cdef int mode = -1 + cdef int i + for i in range(nr_class): + if actions[i].move == move and (mode == -1 or scores[i] >= score): + mode = actions[i].clas + score = scores[i] + return mode