From 36a34d544bbb57b1da498a5513fa1675d4cfd98a Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Thu, 4 Jun 2015 22:43:03 +0200 Subject: [PATCH] * Refactoring arc_eager, grouping oracle functions into transitions --- spacy/syntax/arc_eager.pyx | 513 +++++++++++++++++++------------------ 1 file changed, 267 insertions(+), 246 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index ff480de40..5576d0330 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -180,7 +180,6 @@ cdef class ArcEager(TransitionSystem): label = labels[s.stack[0]] output[i] += move.label != label and label != -1 - cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef bint[N_MOVES] is_valid is_valid[SHIFT] = _can_shift(s) @@ -210,166 +209,181 @@ cdef class ArcEager(TransitionSystem): return best -cdef int _do_shift(const Transition* self, State* state) except -1: - # Set the dep label, in case we need it after we reduce - if NON_MONOTONIC: - state.sent[state.i].dep = self.label - push_stack(state) +cdef class Shift: + @staticmethod + cdef inline bint is_valid(const State* s) nogil: + return not at_eol(s) - -cdef int _do_left(const Transition* self, State* state) except -1: - # Interpret left-arcs from EOL as attachment to root - if at_eol(state): - add_dep(state, state.stack[0], state.stack[0], self.label) - else: - add_dep(state, state.i, state.stack[0], self.label) - pop_stack(state) - - -cdef int _do_right(const Transition* self, State* state) except -1: - add_dep(state, state.stack[0], state.i, self.label) - push_stack(state) - - -cdef int _do_reduce(const Transition* self, State* state) except -1: - if NON_MONOTONIC and not has_head(get_s0(state)): - add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep) - pop_stack(state) - - -cdef int _do_break(const Transition* self, State* state) except -1: - state.sent[state.i-1].sent_end = True - while state.stack_len != 0: - if get_s0(state).head == 0: - get_s0(state).dep = self.label - state.stack -= 1 - state.stack_len -= 1 - if not at_eol(state): + @staticmethod + cdef int transition(State* state, int label) except -1: + # Set the dep label, in case we need it after we reduce + if NON_MONOTONIC: + state.sent[state.i].dep = label push_stack(state) -cdef int _shift_cost(const Transition* self, const State* s, GoldParseC* gold) except -1: - if not _can_shift(s): - return 9000 - cost = 0 - cost += head_in_stack(s, s.i, gold.heads) - cost += children_in_stack(s, s.i, gold.heads) - # If we can break, and there's no cost to doing so, we should - if _can_break(s) and _break_cost(self, s, gold) == 0: - cost += 1 - return cost - -cdef int _right_cost(const Transition* self, const State* s, GoldParseC* gold) except -1: - if not _can_right(s): - return 9000 - cost = 0 - if gold.heads[s.i] == s.stack[0]: - cost += self.label != -1 and self.label != gold.labels[s.i] - return cost - # This indicates missing head - if gold.labels[s.i] != -1: - cost += head_in_buffer(s, s.i, gold.heads) - cost += children_in_stack(s, s.i, gold.heads) - cost += head_in_stack(s, s.i, gold.heads) - return cost - - -cdef int _left_cost(const Transition* self, const State* s, GoldParseC* gold) except -1: - if not _can_left(s): - return 9000 - cost = 0 - if gold.heads[s.stack[0]] == s.i: - cost += self.label != -1 and self.label != gold.labels[s.stack[0]] - return cost - # If we're at EOL, then the left arc will add an arc to ROOT. - elif at_eol(s): - # Are we root? - if gold.labels[s.stack[0]] != -1: - # If we're at EOL, prefer to reduce or break over left-arc - if _can_reduce(s) or _can_break(s): - cost += gold.heads[s.stack[0]] != s.stack[0] - # Are we labelling correctly? - cost += self.label != -1 and self.label != gold.labels[s.stack[0]] + @staticmethod + cdef int cost(const State* s, GoldParseC* gold, int label) except -1: + if not _can_shift(s): + return 9000 + cost = 0 + cost += head_in_stack(s, s.i, gold.heads) + cost += children_in_stack(s, s.i, gold.heads) + # If we can break, and there's no cost to doing so, we should + if _can_break(s) and _break_cost(self, s, gold) == 0: + cost += 1 return cost - cost += head_in_buffer(s, s.stack[0], gold.heads) - cost += children_in_buffer(s, s.stack[0], gold.heads) - if NON_MONOTONIC and s.stack_len >= 2: - cost += gold.heads[s.stack[0]] == s.stack[-1] - if gold.labels[s.stack[0]] != -1: - cost += gold.heads[s.stack[0]] == s.stack[0] - return cost + +cdef class Reduce: + @staticmethod + cdef inline bint is_valid(const State* s) nogil: + if NON_MONOTONIC: + return s.stack_len >= 2 #and not missing_brackets(s) + else: + return s.stack_len >= 2 and has_head(get_s0(s)) + + @staticmethod + cdef int transition(State* state, int label) except -1: + if NON_MONOTONIC and not has_head(get_s0(state)): + add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep) + pop_stack(state) + + @staticmethod + cdef int cost(const State* s, GoldParseC* gold, int label) except -1: + if not Reduce.is_valid(s): + return 9000 + cdef int cost = 0 + cost += children_in_buffer(s, s.stack[0], gold.heads) + if NON_MONOTONIC: + cost += head_in_buffer(s, s.stack[0], gold.heads) + return cost -cdef int _reduce_cost(const Transition* self, const State* s, GoldParseC* gold) except -1: - if not _can_reduce(s): - return 9000 - cdef int cost = 0 - cost += children_in_buffer(s, s.stack[0], gold.heads) - if NON_MONOTONIC: +cdef class LeftArc: + @staticmethod + cdef inline bint is_valid(const State* s) nogil: + if NON_MONOTONIC: + return s.stack_len >= 1 #and not missing_brackets(s) + else: + return s.stack_len >= 1 and not has_head(get_s0(s)) + + @staticmethod + cdef int transition(State* state, int label) except -1: + # Interpret left-arcs from EOL as attachment to root + if at_eol(state): + add_dep(state, state.stack[0], state.stack[0], label) + else: + add_dep(state, state.i, state.stack[0], label) + pop_stack(state) + + @staticmethod + cdef int cost(const State* s, GoldParseC* gold, int label) except -1: + if not _can_left(s): + return 9000 + cost = 0 + if gold.heads[s.stack[0]] == s.i: + cost += self.label != -1 and self.label != gold.labels[s.stack[0]] + return cost + # If we're at EOL, then the left arc will add an arc to ROOT. + elif at_eol(s): + # Are we root? + if gold.labels[s.stack[0]] != -1: + # If we're at EOL, prefer to reduce or break over left-arc + if _can_reduce(s) or _can_break(s): + cost += gold.heads[s.stack[0]] != s.stack[0] + # Are we labelling correctly? + cost += label != -1 and label != gold.labels[s.stack[0]] + return cost cost += head_in_buffer(s, s.stack[0], gold.heads) - return cost + cost += children_in_buffer(s, s.stack[0], gold.heads) + if NON_MONOTONIC and s.stack_len >= 2: + cost += gold.heads[s.stack[0]] == s.stack[-1] + if gold.labels[s.stack[0]] != -1: + cost += gold.heads[s.stack[0]] == s.stack[0] + return cost -cdef int _break_cost(const Transition* self, const State* s, GoldParseC* gold) except -1: - if not _can_break(s): - return 9000 - # When we break, we Reduce all of the words on the stack. - cdef int cost = 0 - # Number of deps between S0...Sn and N0...Nn - for i in range(s.i, s.sent_len): - cost += children_in_stack(s, i, gold.heads) - cost += head_in_stack(s, i, gold.heads) - return cost +cdef class RightArc: + @staticmethod + cdef inline bint is_valid(const State* s) nogil: + return s.stack_len >= 1 and not at_eol(s) + + @staticmethod + cdef int transition(State* state, int label) except -1: + add_dep(state, state.stack[0], state.i, label) + push_stack(state) + + @staticmethod + cdef int cost(const State* s, GoldParseC* gold, int label) except -1: + if not RightArc.is_valid(s): + return 9000 + cost = 0 + if gold.heads[s.i] == s.stack[0]: + cost += label != -1 and self.label != gold.labels[s.i] + return cost + # This indicates missing head + if gold.labels[s.i] != -1: + cost += head_in_buffer(s, s.i, gold.heads) + cost += children_in_stack(s, s.i, gold.heads) + cost += head_in_stack(s, s.i, gold.heads) + return cost -cdef inline bint _can_shift(const State* s) nogil: - return not at_eol(s) +cdef class Break: + @staticmethod + cdef inline bint is_valid(const State* s) nogil: + cdef int i + if not USE_BREAK: + return False + elif at_eol(s): + return False + #elif NON_MONOTONIC: + # return True + else: + # In the Break transition paper, they have this constraint that prevents + # Break if stack is disconnected. But, if we're doing non-monotonic parsing, + # we prefer to relax this constraint. This is helpful in parsing whole + # documents, because then we don't get stuck with words on the stack. + seen_headless = False + for i in range(s.stack_len): + if s.sent[s.stack[-i]].head == 0: + if seen_headless: + return False + else: + seen_headless = True + # TODO: Constituency constraints + return True + + @staticmethod + cdef int transition(State* state, int label) except -1: + state.sent[state.i-1].sent_end = True + while state.stack_len != 0: + if get_s0(state).head == 0: + get_s0(state).dep = label + state.stack -= 1 + state.stack_len -= 1 + if not at_eol(state): + push_stack(state) + + @staticmethod + cdef int cost(const State* s, GoldParseC* gold, int label) except -1: + if not Break.is_valid(s): + return 9000 + # When we break, we Reduce all of the words on the stack. + cdef int cost = 0 + # Number of deps between S0...Sn and N0...Nn + for i in range(s.i, s.sent_len): + cost += children_in_stack(s, i, gold.heads) + cost += head_in_stack(s, i, gold.heads) + return cost -cdef inline bint _can_right(const State* s) nogil: - return s.stack_len >= 1 and not at_eol(s) - - -cdef inline bint _can_left(const State* s) nogil: - if NON_MONOTONIC: - return s.stack_len >= 1 #and not missing_brackets(s) - else: - return s.stack_len >= 1 and not has_head(get_s0(s)) - - -cdef inline bint _can_reduce(const State* s) nogil: - if NON_MONOTONIC: - return s.stack_len >= 2 #and not missing_brackets(s) - else: - return s.stack_len >= 2 and has_head(get_s0(s)) - -cdef inline bint _can_break(const State* s) nogil: - cdef int i - if not USE_BREAK: +cdef class Constituent: + @staticmethod + cdef inline bint is_valid(const State* s) nogil: + if s.stack_len < 1: + return False return False - elif at_eol(s): - return False - #elif NON_MONOTONIC: - # return True - else: - # In the Break transition paper, they have this constraint that prevents - # Break if stack is disconnected. But, if we're doing non-monotonic parsing, - # we prefer to relax this constraint. This is helpful in parsing whole - # documents, because then we don't get stuck with words on the stack. - seen_headless = False - for i in range(s.stack_len): - if s.sent[s.stack[-i]].head == 0: - if seen_headless: - return False - else: - seen_headless = True - # TODO: Constituency constraints - return True - -cdef inline bint _can_constituent(const State* s) nogil: - if s.stack_len < 1: - return False - return False #else: # # If all stack elements are popped, can't constituent # for i in range(s.ctnts.stack_len): @@ -378,112 +392,119 @@ cdef inline bint _can_constituent(const State* s) nogil: # else: # return False -cdef inline bint _can_adjust(const State* s) nogil: - return False - #if s.ctnts.stack_len < 2: - # return False + @staticmethod + cdef int transition(State* state, int label) except -1: + return False + #cdef Constituent* bracket = new_bracket(state.ctnts) - #cdef const Constituent* b1 = s.ctnts.stack[-1] - #cdef const Constituent* b0 = s.ctnts.stack[0] + #bracket.parent = NULL + #bracket.label = self.label + #bracket.head = get_s0(state) + #bracket.length = 0 - #if (b1.head + b1.head.head) != b0.head: - # return False - #elif b0.head >= b1.head: - # return False - #elif b0 >= b1: - # return False + #attach(bracket, state.ctnts.stack) + # Attach rightward children. They're in the brackets array somewhere + # between here and B0. + #cdef Constituent* node + #cdef const TokenC* node_gov + #for i in range(1, bracket - state.ctnts.stack): + # node = bracket - i + # node_gov = node.head + node.head.head + # if node_gov == bracket.head: + # attach(bracket, node) -cdef int _constituent_cost(const Transition* self, const State* s, GoldParseC* gold) except -1: - if not _can_constituent(s): - return 9000 - raise Exception("Constituent move should be disabled currently") - # The gold standard is indexed by end, then by start, then a set of labels - #brackets = gold.brackets(get_s0(s).r_edge, {}) - #if not brackets: - # return 2 # 2 loss for bad bracket, only 1 for good bracket bad label - # Index the current brackets in the state - #existing = set() - #for i in range(s.ctnt_len): - # if ctnt.end == s.r_edge and ctnt.label == self.label: - # existing.add(ctnt.start) - #cdef int loss = 2 - #cdef const TokenC* child - #cdef const TokenC* s0 = get_s0(s) - #cdef int n_left = count_left_kids(s0) - # Iterate over the possible start positions, and check whether we have a - # (start, end, label) match to the gold tree - #for i in range(1, n_left): - # child = get_left(s, s0, i) - # if child.l_edge in brackets and child.l_edge not in existing: - # if self.label in brackets[child.l_edge] - # return 0 - # else: - # loss = 1 # If we see the start position, set loss to 1 - #return loss - -cdef int _adjust_cost(const Transition* self, const State* s, GoldParseC* gold) except -1: - if not _can_adjust(s): - return 9000 - raise Exception("Adjust move should be disabled currently") - # The gold standard is indexed by end, then by start, then a set of labels - #gold_starts = gold.brackets(get_s0(s).r_edge, {}) - # Case 1: There are 0 brackets ending at this word. - # --> Cost is sunk, but must allow brackets to begin - #if not gold_starts: - # return 0 - # Is the top bracket correct? - #gold_labels = gold_starts.get(s.ctnt.start, set()) - # TODO: Case where we have a unary rule - # TODO: Case where two brackets end on this word, with top bracket starting - # before - - #cdef const TokenC* child - #cdef const TokenC* s0 = get_s0(s) - #cdef int n_left = count_left_kids(s0) - #cdef int i - # Iterate over the possible start positions, and check whether we have a - # (start, end, label) match to the gold tree - #for i in range(1, n_left): - # child = get_left(s, s0, i) - # if child.l_edge in brackets: - # if self.label in brackets[child.l_edge]: - # return 0 - # else: - # loss = 1 # If we see the start position, set loss to 1 - #return loss + @staticmethod + cdef int cost(const State* s, GoldParseC* gold, int label) except -1: + if not Constituent.is_valid(s): + return 9000 + raise Exception("Constituent move should be disabled currently") + # The gold standard is indexed by end, then by start, then a set of labels + #brackets = gold.brackets(get_s0(s).r_edge, {}) + #if not brackets: + # return 2 # 2 loss for bad bracket, only 1 for good bracket bad label + # Index the current brackets in the state + #existing = set() + #for i in range(s.ctnt_len): + # if ctnt.end == s.r_edge and ctnt.label == self.label: + # existing.add(ctnt.start) + #cdef int loss = 2 + #cdef const TokenC* child + #cdef const TokenC* s0 = get_s0(s) + #cdef int n_left = count_left_kids(s0) + # Iterate over the possible start positions, and check whether we have a + # (start, end, label) match to the gold tree + #for i in range(1, n_left): + # child = get_left(s, s0, i) + # if child.l_edge in brackets and child.l_edge not in existing: + # if self.label in brackets[child.l_edge] + # return 0 + # else: + # loss = 1 # If we see the start position, set loss to 1 + #return loss -cdef int _do_constituent(const Transition* self, State* state) except -1: - return False - #cdef Constituent* bracket = new_bracket(state.ctnts) +cdef class Adjust: + @staticmethod + cdef inline bint is_valid(const State* s) nogil: + return False + #if s.ctnts.stack_len < 2: + # return False - #bracket.parent = NULL - #bracket.label = self.label - #bracket.head = get_s0(state) - #bracket.length = 0 + #cdef const Constituent* b1 = s.ctnts.stack[-1] + #cdef const Constituent* b0 = s.ctnts.stack[0] - #attach(bracket, state.ctnts.stack) - # Attach rightward children. They're in the brackets array somewhere - # between here and B0. - #cdef Constituent* node - #cdef const TokenC* node_gov - #for i in range(1, bracket - state.ctnts.stack): - # node = bracket - i - # node_gov = node.head + node.head.head - # if node_gov == bracket.head: - # attach(bracket, node) + #if (b1.head + b1.head.head) != b0.head: + # return False + #elif b0.head >= b1.head: + # return False + #elif b0 >= b1: + # return False + + @staticmethod + cdef int transition(State* state) except -1: + return False + #cdef Constituent* b0 = state.ctnts.stack[0] + #cdef Constituent* b1 = state.ctnts.stack[1] + + #assert (b1.head + b1.head.head) == b0.head + #assert b0.head < b1.head + #assert b0 < b1 + + #attach(b0, b1) + ## Pop B1 from stack, but keep B0 on top + #state.ctnts.stack -= 1 + #state.ctnts.stack[0] = b0 + + @staticmethod + cdef int cost(const State* s, GoldParseC* gold, int label) except -1: + if not Adjust.is_valid(s): + return 9000 + raise Exception("Adjust move should be disabled currently") + # The gold standard is indexed by end, then by start, then a set of labels + #gold_starts = gold.brackets(get_s0(s).r_edge, {}) + # Case 1: There are 0 brackets ending at this word. + # --> Cost is sunk, but must allow brackets to begin + #if not gold_starts: + # return 0 + # Is the top bracket correct? + #gold_labels = gold_starts.get(s.ctnt.start, set()) + # TODO: Case where we have a unary rule + # TODO: Case where two brackets end on this word, with top bracket starting + # before + + #cdef const TokenC* child + #cdef const TokenC* s0 = get_s0(s) + #cdef int n_left = count_left_kids(s0) + #cdef int i + # Iterate over the possible start positions, and check whether we have a + # (start, end, label) match to the gold tree + #for i in range(1, n_left): + # child = get_left(s, s0, i) + # if child.l_edge in brackets: + # if self.label in brackets[child.l_edge]: + # return 0 + # else: + # loss = 1 # If we see the start position, set loss to 1 + #return loss -cdef int _do_adjust(const Transition* self, State* state) except -1: - return False - #cdef Constituent* b0 = state.ctnts.stack[0] - #cdef Constituent* b1 = state.ctnts.stack[1] - - #assert (b1.head + b1.head.head) == b0.head - #assert b0.head < b1.head - #assert b0 < b1 - - #attach(b0, b1) - ## Pop B1 from stack, but keep B0 on top - #state.ctnts.stack -= 1 - #state.ctnts.stack[0] = b0