* Ensure high loss for invalid moves, and fix label reading for arc-eager

This commit is contained in:
Matthew Honnibal 2015-03-08 00:13:20 -05:00
parent f5f15a1ef2
commit fdabd93bfb
1 changed files with 20 additions and 10 deletions

View File

@ -35,13 +35,14 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
cdef class ArcEager(TransitionSystem): cdef class ArcEager(TransitionSystem):
@classmethod @classmethod
def get_labels(cls, gold_parses): def get_labels(cls, gold_parses):
labels = {SHIFT: {0: True}, REDUCE: {0: True}, RIGHT: {0: True}, labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {},
LEFT: {0: True}, BREAK: {0: True}} LEFT: {}, BREAK: {'ROOT': True}}
for parse in gold_parses: for parse in gold_parses:
for i, (head, label) in enumerate(zip(parse.heads, parse.labels)): for i, (head, label) in enumerate(zip(parse.heads, parse.labels)):
if label != 'ROOT':
if head > i: if head > i:
labels[RIGHT][label] = True labels[RIGHT][label] = True
else: elif head < i:
labels[LEFT][label] = True labels[LEFT][label] = True
return labels return labels
@ -71,6 +72,8 @@ cdef class ArcEager(TransitionSystem):
if scores[i] > score and is_valid[self.c[i].move]: if scores[i] > score and is_valid[self.c[i].move]:
best = self.c[i] best = self.c[i]
score = scores[i] score = scores[i]
assert best.clas < self.n_moves
assert score > MIN_SCORE
# Label Shift moves with the best Right-Arc label, for non-monotonic # Label Shift moves with the best Right-Arc label, for non-monotonic
# actions # actions
if best.move == SHIFT: if best.move == SHIFT:
@ -85,7 +88,7 @@ cdef class ArcEager(TransitionSystem):
cdef int _do_shift(const Transition* self, State* state) except -1: cdef int _do_shift(const Transition* self, State* state) except -1:
# Set the dep label, in case we need it after we reduce # Set the dep label, in case we need it after we reduce
if NON_MONOTONIC: if NON_MONOTONIC:
get_s0(state).dep = self.label state.sent[state.i].dep = self.label
push_stack(state) push_stack(state)
@ -124,7 +127,8 @@ do_funcs[BREAK] = _do_break
cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1: cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert not at_eol(s) if not _can_shift(s):
return 9000
cost = 0 cost = 0
cost += head_in_stack(s, s.i, gold.c_heads) cost += head_in_stack(s, s.i, gold.c_heads)
cost += children_in_stack(s, s.i, gold.c_heads) cost += children_in_stack(s, s.i, gold.c_heads)
@ -137,7 +141,8 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc
cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1: cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1 if not _can_right(s):
return 9000
cost = 0 cost = 0
if gold.c_heads[s.i] == s.stack[0]: if gold.c_heads[s.i] == s.stack[0]:
cost += self.label != gold.c_labels[s.i] cost += self.label != gold.c_labels[s.i]
@ -151,7 +156,8 @@ cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) exc
cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1: cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) except -1:
assert s.stack_len >= 1 if not _can_left(s):
return 9000
cost = 0 cost = 0
if gold.c_heads[s.stack[0]] == s.i: if gold.c_heads[s.stack[0]] == s.i:
cost += self.label != gold.c_labels[s.stack[0]] cost += self.label != gold.c_labels[s.stack[0]]
@ -166,6 +172,8 @@ cdef int _left_cost(const Transition* self, const State* s, GoldParse gold) exce
cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1: cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) except -1:
if not _can_reduce(s):
return 9000
cdef int cost = 0 cdef int cost = 0
cost += children_in_buffer(s, s.stack[0], gold.c_heads) cost += children_in_buffer(s, s.stack[0], gold.c_heads)
if NON_MONOTONIC: if NON_MONOTONIC:
@ -174,6 +182,8 @@ cdef int _reduce_cost(const Transition* self, const State* s, GoldParse gold) ex
cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1: cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) except -1:
if not _can_break(s):
return 9000
# When we break, we Reduce all of the words on the stack. # When we break, we Reduce all of the words on the stack.
cdef int cost = 0 cdef int cost = 0
# Number of deps between S0...Sn and N0...Nn # Number of deps between S0...Sn and N0...Nn