mirror of https://github.com/explosion/spaCy.git
* Fix label prediction in StepwiseState
This commit is contained in:
parent
2c9753eff2
commit
6116413b47
|
@ -172,8 +172,16 @@ cdef class StepwiseState:
|
|||
return self.parser.moves.move_name(action.move, action.label)
|
||||
|
||||
def transition(self, action_name):
|
||||
moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3}
|
||||
if action_name == '_':
|
||||
action_name = self.predict()
|
||||
if action_name == 'L' or action_name == 'R':
|
||||
self.predict()
|
||||
move = moves[action_name]
|
||||
clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c,
|
||||
self.eg.c.nr_class)
|
||||
action = self.parser.moves.c[clas]
|
||||
else:
|
||||
action = self.parser.moves.lookup_transition(action_name)
|
||||
action.do(self.stcls, action.label)
|
||||
|
||||
|
@ -181,3 +189,15 @@ cdef class StepwiseState:
|
|||
if self.stcls.is_final():
|
||||
self.parser.moves.finalize_state(self.stcls)
|
||||
self.doc.set_parse(self.stcls._sent)
|
||||
|
||||
|
||||
cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
|
||||
int nr_class) except -1:
|
||||
cdef weight_t score = 0
|
||||
cdef int mode = -1
|
||||
cdef int i
|
||||
for i in range(nr_class):
|
||||
if actions[i].move == move and (mode == -1 or scores[i] >= score):
|
||||
mode = actions[i].clas
|
||||
score = scores[i]
|
||||
return mode
|
||||
|
|
Loading…
Reference in New Issue