diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 1fdab2f1f..99fe7f943 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -19,7 +19,7 @@ from .stateclass cimport StateClass DEF NON_MONOTONIC = True -DEF USE_BREAK = True +DEF USE_BREAK = False cdef weight_t MIN_SCORE = -90000 @@ -70,12 +70,14 @@ cdef int pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: break return cost + cdef int arc_cost(StateClass stcls, const GoldParseC* gold, int head, int child) nogil: if arc_is_gold(gold, head, child): return 0 elif stcls.H(child) == gold.heads[child]: return 1 - elif gold.heads[child] >= stcls.B(0): + # Head in buffer + elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1: return 1 else: return 0 @@ -110,13 +112,10 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: cdef class Shift: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: - return not st.eol() + return st.buffer_length() >= 2 and not st.shifted[st.B(0)] @staticmethod cdef int transition(StateClass state, int label) nogil: - # Set the dep label, in case we need it after we reduce - if NON_MONOTONIC: - state._sent[state.B(0)].dep = label state.push() @staticmethod @@ -135,27 +134,25 @@ cdef class Shift: cdef class Reduce: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: - if NON_MONOTONIC: - return st.stack_depth() >= 2 #and not missing_brackets(s) - else: - return st.stack_depth() >= 2 and st.has_head(st.S(0)) + return st.stack_depth() >= 2 @staticmethod cdef int transition(StateClass st, int label) nogil: - if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2: - st.add_arc(st.S(1), st.S(0), st.S_(0).dep) - st.pop() + if st.has_head(st.S(0)): + st.pop() + else: + st.unshift() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: return Reduce.move_cost(s, gold) + Reduce.label_cost(s, gold, label) @staticmethod - cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: - if NON_MONOTONIC: - return pop_cost(s, gold, s.S(0)) + cdef inline int move_cost(StateClass st, const GoldParseC* gold) nogil: + if st.shifted[st.S(0)]: + return pop_cost(st, gold, st.S(0)) else: - return children_in_buffer(s, s.S(0), gold.heads) + return 0 @staticmethod cdef inline int label_cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -166,18 +163,16 @@ cdef class LeftArc: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: if NON_MONOTONIC: - return st.stack_depth() >= 1 #and not missing_brackets(s) + return st.stack_depth() >= 1 and st.buffer_length() >= 1 #and not missing_brackets(s) else: - return st.stack_depth() >= 1 and not st.has_head(st.S(0)) + return st.stack_depth() >= 1 and st.buffer_length() >= 1 and not st.has_head(st.S(0)) @staticmethod cdef int transition(StateClass st, int label) nogil: - # Interpret left-arcs from EOL as attachment to root - if st.eol(): - st.add_arc(st.S(0), st.S(0), label) - else: - st.add_arc(st.B(0), st.S(0), label) + st.add_arc(st.B(0), st.S(0), label) st.pop() + if st.empty(): + st.push() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -198,7 +193,7 @@ cdef class LeftArc: cdef class RightArc: @staticmethod cdef bint is_valid(StateClass st, int label) nogil: - return st.stack_depth() >= 1 and not st.eol() + return st.stack_depth() >= 1 and st.buffer_length() >= 1 @staticmethod cdef int transition(StateClass st, int label) nogil: @@ -213,6 +208,8 @@ cdef class RightArc: cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: if arc_is_gold(gold, s.S(0), s.B(0)): return 0 + elif s.shifted[s.B(0)]: + return push_cost(s, gold, s.B(0)) else: return push_cost(s, gold, s.B(0)) + arc_cost(s, gold, s.S(0), s.B(0)) @@ -231,30 +228,13 @@ cdef class Break: return False elif st.stack_depth() < 1: return False - elif NON_MONOTONIC: - return True else: - # In the Break transition paper, they have this constraint that prevents - # Break if stack is disconnected. But, if we're doing non-monotonic parsing, - # we prefer to relax this constraint. This is helpful in parsing whole - # documents, because then we don't get stuck with words on the stack. - seen_headless = False - for i in range(st.stack_depth()): - if not st.has_head(st.S(i)): - if seen_headless: - return False - else: - seen_headless = True - # TODO: Constituency constraints return True @staticmethod cdef int transition(StateClass st, int label) nogil: - st.set_sent_end(st.B(0)-1) - while not st.empty(): - if not st.has_head(st.S(0)): - st._sent[st.S(0)].dep = label - st.pop() + #st.set_sent_end() + pass @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) nogil: @@ -262,9 +242,9 @@ cdef class Break: @staticmethod cdef inline int move_cost(StateClass s, const GoldParseC* gold) nogil: - # When we break, we Reduce all of the words on the stack. + # When we break, we can't reach any arcs between stack and buffer + # So cost is number of deps between S0...Sn and B0...Nn cdef int cost = 0 - # Number of deps between S0...Sn and N0...Nn cdef int i, j, B_i, S_i for i in range(s.buffer_length()): B_i = s.B(i) @@ -432,7 +412,7 @@ cdef class ArcEager(TransitionSystem): best = self.c[i] score = scores[i] assert best.clas < self.n_moves - assert score > MIN_SCORE + assert score > MIN_SCORE, (stcls.stack_depth(), stcls.buffer_length()) # Label Shift moves with the best Right-Arc label, for non-monotonic # actions if best.move == SHIFT: