* Fix label prediction in StepwiseState

This commit is contained in:
Matthew Honnibal 2015-08-10 05:05:31 +02:00
parent 2c9753eff2
commit 6116413b47
1 changed files with 21 additions and 1 deletions

View File

@ -172,8 +172,16 @@ cdef class StepwiseState:
return self.parser.moves.move_name(action.move, action.label) return self.parser.moves.move_name(action.move, action.label)
def transition(self, action_name): def transition(self, action_name):
moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3}
if action_name == '_': if action_name == '_':
action_name = self.predict() action_name = self.predict()
if action_name == 'L' or action_name == 'R':
self.predict()
move = moves[action_name]
clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c,
self.eg.c.nr_class)
action = self.parser.moves.c[clas]
else:
action = self.parser.moves.lookup_transition(action_name) action = self.parser.moves.lookup_transition(action_name)
action.do(self.stcls, action.label) action.do(self.stcls, action.label)
@ -181,3 +189,15 @@ cdef class StepwiseState:
if self.stcls.is_final(): if self.stcls.is_final():
self.parser.moves.finalize_state(self.stcls) self.parser.moves.finalize_state(self.stcls)
self.doc.set_parse(self.stcls._sent) self.doc.set_parse(self.stcls._sent)
cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
int nr_class) except -1:
cdef weight_t score = 0
cdef int mode = -1
cdef int i
for i in range(nr_class):
if actions[i].move == move and (mode == -1 or scores[i] >= score):
mode = actions[i].clas
score = scores[i]
return mode