diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index acbc4ac87..80cf9deaf 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -18,8 +18,7 @@ cdef enum: REDUCE LEFT RIGHT - BREAK_SHIFT - BREAK_RIGHT + BREAK N_MOVES # Break transition from here @@ -48,35 +47,14 @@ cdef inline bint _can_reduce(const State* s) nogil: return s.stack_len >= 2 and has_head(get_s0(s)) -cdef inline bint _can_break_shift(const State* s) nogil: +cdef inline bint _can_break(const State* s) nogil: cdef int i if not USE_BREAK: return False elif at_eol(s): return False else: - # P. 757 - # In UPP, if Shift(F) or RightArc(F) fail to result in a single parsing - # tree, they cannot be performed as well. - seen_headless = False - for i in range(s.stack_len): - if seen_headless: - return False - else: - seen_headless = True - return True - - -cdef inline bint _can_break_right(const State* s) nogil: - cdef int i - if not USE_BREAK: - return False - elif not _can_right(s): - return False - else: - # P. 757 - # In UPP, if Shift(F) or RightArc(F) fail to result in a single parsing - # tree, they cannot be performed as well. + # If stack is disconnected, cannot break seen_headless = False for i in range(s.stack_len): if s.sent[s.stack[-i]].head == 0: @@ -95,7 +73,7 @@ cdef int _shift_cost(const State* s, const int* gold) except -1: if NON_MONOTONIC: cost += gold[s.stack[0]] == s.i # If we can break, and there's no cost to doing so, we should - if _can_break_shift(s) and _break_shift_cost(s, gold) == 0: + if _can_break(s) and _break_cost(s, gold) == 0: cost += 1 return cost @@ -103,9 +81,6 @@ cdef int _shift_cost(const State* s, const int* gold) except -1: cdef int _right_cost(const State* s, const int* gold) except -1: assert s.stack_len >= 1 cost = 0 - # If we can break, and there's no cost to doing so, we should - if _can_break_right(s) and _break_right_cost(s, gold) == 0: - cost += 1 if gold[s.i] == s.stack[0]: return cost cost += head_in_buffer(s, s.i, gold) @@ -138,11 +113,8 @@ cdef int _reduce_cost(const State* s, const int* gold) except -1: return cost -cdef int _break_shift_cost(const State* s, const int* gold) except -1: - # When we break, we Reduce all of the words on the stack. We also remove - # the first word from the buffer. - # - # n0_cost: +cdef int _break_cost(const State* s, const int* gold) except -1: + # When we break, we Reduce all of the words on the stack. cdef int cost = 0 # Number of deps between S0...Sn and N0...Nn for i in range(s.i, s.sent_len): @@ -151,29 +123,6 @@ cdef int _break_shift_cost(const State* s, const int* gold) except -1: return cost -cdef int _break_right_cost(const State* s, const int* gold) except -1: - cdef int cost = 0 - assert s.stack_len >= 1 - cdef int i - # When we break, we Reduce all of the words on the stack. We also remove - # the first word from the buffer. - # - # n0_cost: - # number of head/child deps between n0 and N0...Nn - cost += children_in_buffer(s, s.i, gold) - cost += head_in_buffer(s, s.i, gold) - # number of child deps from N0 into stack - cost += children_in_stack(s, s.i, gold) - # number of head deps to N0 from S1..Sn - for i in range(1, s.stack_len): - cost += s.stack[-i] == gold[s.i] - # Number of deps between S0...Sn and N1...Nn - for i in range(s.i+1, s.sent_len): - cost += children_in_stack(s, i, gold) - cost += head_in_stack(s, i, gold) - return cost - - cdef class TransitionSystem: def __init__(self, list left_labels, list right_labels): self.mem = Pool() @@ -183,7 +132,7 @@ cdef class TransitionSystem: right_labels.pop(right_labels.index('ROOT')) if 'ROOT' in left_labels: left_labels.pop(left_labels.index('ROOT')) - self.n_moves = 3 + len(left_labels) + len(right_labels) + len(right_labels) + self.n_moves = 3 + len(left_labels) + len(right_labels) moves = self.mem.alloc(self.n_moves, sizeof(Transition)) cdef int i = 0 moves[i].move = SHIFT @@ -210,17 +159,10 @@ cdef class TransitionSystem: moves[i].label = label_id moves[i].clas = i i += 1 - moves[i].move = BREAK_SHIFT + moves[i].move = BREAK moves[i].label = 0 moves[i].clas = i i += 1 - for label_str in right_labels: - label_str = unicode(label_str) - label_id = self.label_ids.setdefault(label_str, len(self.label_ids)) - moves[i].move = BREAK_RIGHT - moves[i].label = label_id - moves[i].clas = i - i += 1 self._moves = moves cdef int transition(self, State *s, const Transition* t) except -1: @@ -239,16 +181,10 @@ cdef class TransitionSystem: # TODO: Huh? Is this some weirdness from the non-monotonic? add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep) pop_stack(s) - elif t.move == BREAK_RIGHT: - add_dep(s, s.stack[0], s.i, t.label) - push_stack(s) - while s.stack_len != 0: - s.stack -= 1 - s.stack_len -= 1 - if not at_eol(s): - push_stack(s) - elif t.move == BREAK_SHIFT: + elif t.move == BREAK: while s.stack_len != 0: + if get_s0(s).head == 0: + get_s0(s).dep = 0 s.stack -= 1 s.stack_len -= 1 if not at_eol(s): @@ -262,8 +198,7 @@ cdef class TransitionSystem: valid[LEFT] = _can_left(s) valid[RIGHT] = _can_right(s) valid[REDUCE] = _can_reduce(s) - valid[BREAK_SHIFT] = _can_break_shift(s) - valid[BREAK_RIGHT] = _can_break_right(s) + valid[BREAK] = _can_break(s) cdef int best = -1 cdef weight_t score = 0 @@ -292,8 +227,7 @@ cdef class TransitionSystem: unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1 unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1 unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1 - unl_costs[BREAK_SHIFT] = _break_shift_cost(s, gold_heads) if _can_break_shift(s) else -1 - unl_costs[BREAK_RIGHT] = _break_right_cost(s, gold_heads) if _can_break_right(s) else -1 + unl_costs[BREAK] = _break_cost(s, gold_heads) if _can_break(s) else -1 guess.cost = unl_costs[guess.move] cdef Transition t @@ -309,7 +243,7 @@ cdef class TransitionSystem: return t elif gold_heads[s.i] == s.stack[0]: target_label = gold_labels[s.i] - if guess.move == RIGHT or guess.move == BREAK_RIGHT: + if guess.move == RIGHT: if unl_costs[guess.move] != 0: guess.cost += guess.label != target_label for i in range(self.n_moves):