From a2d6b195dbb990767f93488bfae54ee98e891f50 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 28 Jan 2015 03:09:45 +1100 Subject: [PATCH] * Add messy Break transitions, carefully following the scheme of Dd Zhang et al (2013) --- spacy/syntax/arc_eager.pyx | 101 +++++++++++++++++++++++++++++++++++-- 1 file changed, 98 insertions(+), 3 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index f9ae320e5..55c20eb33 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -8,7 +8,9 @@ from ._state cimport head_in_stack, children_in_stack from ..structs cimport TokenC + DEF NON_MONOTONIC = True +DEF USE_BREAK = True cdef enum: @@ -16,6 +18,8 @@ cdef enum: REDUCE LEFT RIGHT + BREAK_SHIFT + BREAK_RIGHT N_MOVES @@ -41,6 +45,43 @@ cdef inline bint _can_reduce(const State* s) nogil: return s.stack_len >= 2 and has_head(get_s0(s)) +cdef inline bint _can_break_shift(const State* s) nogil: + cdef int i + if not USE_BREAK: + return False + elif not _can_shift(s): + return False + else: + # P. 757 + # In UPP, if Shift(F) or RightArc(F) fail to result in a single parsing + # tree, they cannot be performed as well. + seen_headless = False + for i in range(s.stack_len): + if s.sent[s.stack[i]].head == 0: + return False + return True + + +cdef inline bint _can_break_right(const State* s) nogil: + cdef int i + if not USE_BREAK: + return False + elif not _can_right(s): + return False + else: + # P. 757 + # In UPP, if Shift(F) or RightArc(F) fail to result in a single parsing + # tree, they cannot be performed as well. + seen_headless = False + for i in range(s.stack_len): + if s.sent[s.stack[i]].head == 0: + if seen_headless: + return False + else: + seen_headless = True + return True + + cdef int _shift_cost(const State* s, const int* gold) except -1: assert not at_eol(s) cost = 0 @@ -85,6 +126,28 @@ cdef int _reduce_cost(const State* s, const int* gold) except -1: return cost +cdef int _break_shift_cost(const State* s, const int* gold) except -1: + cdef int cost = _shift_cost(s, gold) + # When we break, we Reduce all of the words on the stack. So, the Break + # cost is the sum of the Reduce costs + for i in range(s.stack_len): + cost += children_in_buffer(s, s.stack[i], gold) + if NON_MONOTONIC: + cost += head_in_buffer(s, s.stack[i], gold) + return cost + + +cdef int _break_right_cost(const State* s, const int* gold) except -1: + cdef int cost = _right_cost(s, gold) + # When we break, we Reduce all of the words on the stack. So, the Break + # cost is the sum of the Reduce costs + for i in range(s.stack_len): + cost += children_in_buffer(s, s.stack[i], gold) + if NON_MONOTONIC: + cost += head_in_buffer(s, s.stack[i], gold) + return cost + + cdef class TransitionSystem: def __init__(self, list left_labels, list right_labels): self.mem = Pool() @@ -94,7 +157,7 @@ cdef class TransitionSystem: right_labels.pop(right_labels.index('ROOT')) if 'ROOT' in left_labels: left_labels.pop(left_labels.index('ROOT')) - self.n_moves = 2 + len(left_labels) + len(right_labels) + self.n_moves = 3 + len(left_labels) + len(right_labels) + len(right_labels) moves = self.mem.alloc(self.n_moves, sizeof(Transition)) cdef int i = 0 moves[i].move = SHIFT @@ -121,6 +184,17 @@ cdef class TransitionSystem: moves[i].label = label_id moves[i].clas = i i += 1 + moves[i].move = BREAK_SHIFT + moves[i].label = 0 + moves[i].clas = i + i += 1 + for label_str in right_labels: + label_str = unicode(label_str) + label_id = self.label_ids.setdefault(label_str, len(self.label_ids)) + moves[i].move = BREAK_RIGHT + moves[i].label = label_id + moves[i].clas = i + i += 1 self._moves = moves cdef int transition(self, State *s, const Transition* t) except -1: @@ -138,6 +212,23 @@ cdef class TransitionSystem: elif t.move == REDUCE: add_dep(s, s.stack[-1], s.stack[0], get_s0(s).dep) pop_stack(s) + elif t.move == BREAK_RIGHT: + add_dep(s, s.stack[0], s.i, t.label) + push_stack(s) + while s.stack_len != 0: + if not has_head(get_s0(s)): + get_s0(s).dep = 0 + s.stack -= 1 + s.stack_len -= 1 + if not at_eol(s): + push_stack(s) + elif t.move == BREAK_SHIFT: + push_stack(s) + get_s0(s).dep = 0 + s.stack -= s.stack_len + s.stack_len = 0 + if not at_eol(s): + push_stack(s) else: raise Exception(t.move) @@ -147,6 +238,8 @@ cdef class TransitionSystem: valid[LEFT] = _can_left(s) valid[RIGHT] = _can_right(s) valid[REDUCE] = _can_reduce(s) + valid[BREAK_SHIFT] = _can_break_shift(s) + valid[BREAK_RIGHT] = _can_break_right(s) cdef int best = -1 cdef weight_t score = 0 @@ -175,6 +268,8 @@ cdef class TransitionSystem: unl_costs[LEFT] = _left_cost(s, gold_heads) if _can_left(s) else -1 unl_costs[RIGHT] = _right_cost(s, gold_heads) if _can_right(s) else -1 unl_costs[REDUCE] = _reduce_cost(s, gold_heads) if _can_reduce(s) else -1 + unl_costs[BREAK_SHIFT] = _break_shift_cost(s, gold_heads) if _can_break_shift(s) else -1 + unl_costs[BREAK_RIGHT] = _break_right_cost(s, gold_heads) if _can_break_right(s) else -1 guess.cost = unl_costs[guess.move] cdef Transition t @@ -190,11 +285,11 @@ cdef class TransitionSystem: return t elif gold_heads[s.i] == s.stack[0]: target_label = gold_labels[s.i] - if guess.move == RIGHT: + if guess.move == RIGHT or guess.move == BREAK_RIGHT: guess.cost += guess.label != target_label for i in range(self.n_moves): t = self._moves[i] - if t.move == RIGHT and t.label == target_label: + if (t.move == RIGHT or t.move == BREAK_RIGHT) and t.label == target_label: return t cdef int best = -1