mirror of https://github.com/explosion/spaCy.git
* Fix label prediction in StepwiseState
This commit is contained in:
parent
2c9753eff2
commit
6116413b47
|
@ -172,8 +172,16 @@ cdef class StepwiseState:
|
||||||
return self.parser.moves.move_name(action.move, action.label)
|
return self.parser.moves.move_name(action.move, action.label)
|
||||||
|
|
||||||
def transition(self, action_name):
|
def transition(self, action_name):
|
||||||
|
moves = {'S': 0, 'D': 1, 'L': 2, 'R': 3}
|
||||||
if action_name == '_':
|
if action_name == '_':
|
||||||
action_name = self.predict()
|
action_name = self.predict()
|
||||||
|
if action_name == 'L' or action_name == 'R':
|
||||||
|
self.predict()
|
||||||
|
move = moves[action_name]
|
||||||
|
clas = _arg_max_clas(self.eg.c.scores, move, self.parser.moves.c,
|
||||||
|
self.eg.c.nr_class)
|
||||||
|
action = self.parser.moves.c[clas]
|
||||||
|
else:
|
||||||
action = self.parser.moves.lookup_transition(action_name)
|
action = self.parser.moves.lookup_transition(action_name)
|
||||||
action.do(self.stcls, action.label)
|
action.do(self.stcls, action.label)
|
||||||
|
|
||||||
|
@ -181,3 +189,15 @@ cdef class StepwiseState:
|
||||||
if self.stcls.is_final():
|
if self.stcls.is_final():
|
||||||
self.parser.moves.finalize_state(self.stcls)
|
self.parser.moves.finalize_state(self.stcls)
|
||||||
self.doc.set_parse(self.stcls._sent)
|
self.doc.set_parse(self.stcls._sent)
|
||||||
|
|
||||||
|
|
||||||
|
cdef int _arg_max_clas(const weight_t* scores, int move, const Transition* actions,
|
||||||
|
int nr_class) except -1:
|
||||||
|
cdef weight_t score = 0
|
||||||
|
cdef int mode = -1
|
||||||
|
cdef int i
|
||||||
|
for i in range(nr_class):
|
||||||
|
if actions[i].move == move and (mode == -1 or scores[i] >= score):
|
||||||
|
mode = actions[i].clas
|
||||||
|
score = scores[i]
|
||||||
|
return mode
|
||||||
|
|
Loading…
Reference in New Issue