* Ensure high loss for invalid moves, and fix label reading for arc-eager

This commit is contained in:
Matthew Honnibal 2015-03-08 00:13:20 -05:00
parent f5f15a1ef2
commit fdabd93bfb
1 changed files with 20 additions and 10 deletions

View File

@ -35,14 +35,15 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
cdef class ArcEager(TransitionSystem):
@classmethod
def get_labels(cls, gold_parses):
labels = {SHIFT: {0: True}, REDUCE: {0: True}, RIGHT: {0: True},
LEFT: {0: True}, BREAK: {0: True}}
labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {},
LEFT: {}, BREAK: {'ROOT': True}}
for parse in gold_parses:
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
if head > i:
labels[RIGHT][label] = True
else:
labels[LEFT][label] = True
if label != 'ROOT':
if head > i:
labels[RIGHT][label] = True
elif head < i:
labels[LEFT][label] = True
return labels
cdef Transition init_transition(self, int clas, int move, int label) except *:
@ -71,6 +72,8 @@ cdef class ArcEager(TransitionSystem):
if scores[i] > score and is_valid[self.c[i].move]:
best = self.c[i]
score = scores[i]
assert best.clas < self.n_moves
assert score > MIN_SCORE
# Label Shift moves with the best Right-Arc label, for non-monotonic
# actions
if best.move == SHIFT:
@ -85,7 +88,7 @@ cdef class ArcEager(TransitionSystem):
cdef int _do_shift(const Transition* self, State* state) except -1:
# Set the dep label, in case we need it after we reduce
if NON_MONOTONIC:
get_s0(state).dep = self.label
state.sent[state.i].dep = self.label
push_stack(state)
@ -124,7 +127,8 @@ do_funcs[BREAK] = _do_break
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert not at_eol(s)
if not _can_shift(s):
return 9000
cost = 0
cost += head_in_stack(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads)
@ -137,7 +141,8 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1
if not _can_right(s):
return 9000
cost = 0
if gold.c_heads[s.i] == s.stack[0]:
cost += self.label != gold.c_labels[s.i]
@ -151,7 +156,8 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1
if not _can_left(s):
return 9000
cost = 0
if gold.c_heads[s.stack[0]] == s.i:
cost += self.label != gold.c_labels[s.stack[0]]
@ -166,6 +172,8 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
if not _can_reduce(s):
return 9000
cdef int cost = 0
cost += children_in_buffer(s, s.stack[0], gold.c_heads)
if NON_MONOTONIC:
@ -174,6 +182,8 @@ cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) ex
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
if not _can_break(s):
return 9000
# When we break, we Reduce all of the words on the stack.
cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn