From 0786d9b3c79f271596bbcdeb904056e8272bacec Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Tue, 2 Jun 2015 18:38:07 +0200 Subject: [PATCH] * Refactor TransitionSystem, adding set_valid method --- spacy/syntax/arc_eager.pyx | 255 ++++++++++++++--------------- spacy/syntax/ner.pyx | 5 +- spacy/syntax/transition_system.pxd | 2 +- spacy/syntax/transition_system.pyx | 2 +- 4 files changed, 126 insertions(+), 138 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 946cd540b..7cf2f1d42 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -44,10 +44,6 @@ MOVE_NAMES[CONSTITUENT] = 'C' MOVE_NAMES[ADJUST] = 'A' -cdef do_func_t[N_MOVES] do_funcs -cdef get_cost_func_t[N_MOVES] get_cost_funcs - - cdef class ArcEager(TransitionSystem): @classmethod def get_labels(cls, gold_parses): @@ -107,8 +103,27 @@ cdef class ArcEager(TransitionSystem): t.clas = clas t.move = move t.label = label - t.do = do_funcs[move] - t.get_cost = get_cost_funcs[move] + if move == SHIFT: + t.do = _do_shift + t.get_cost = _shift_cost + elif move == REDUCE: + t.do = _do_reduce + t.get_cost = _reduce_cost + elif move == LEFT: + t.do = _do_left + t.get_cost = _left_cost + elif move == RIGHT: + t.do = _do_right + t.get_cost = _right_cost + elif move == BREAK: + t.get_cost = _break_cost + elif move == CONSTITUENT: + t.get_cost = _constituent_cost + elif move == ADJUST: + t.do = _do_adjust + t.get_cost = _adjust_cost + else: + raise Exception(move) return t cdef int initialize_state(self, State* state) except -1: @@ -120,7 +135,7 @@ cdef class ArcEager(TransitionSystem): if state.sent[i].head == 0 and state.sent[i].dep == 0: state.sent[i].dep = root_label - cdef bint* get_valid(self, const State* s) except NULL: + cdef int set_valid(self, bint* output, const State* s) except -1: cdef bint[N_MOVES] is_valid is_valid[SHIFT] = _can_shift(s) is_valid[REDUCE] = _can_reduce(s) @@ -131,8 +146,7 @@ cdef class ArcEager(TransitionSystem): is_valid[ADJUST] = _can_adjust(s) cdef int i for i in range(self.n_moves): - self._is_valid[i] = is_valid[self.c[i].move] - return self._is_valid + output[i] = is_valid[self.c[i].move] cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef bint[N_MOVES] is_valid @@ -200,52 +214,6 @@ cdef int _do_break(const Transition* self, State* state) except -1: if not at_eol(state): push_stack(state) - -cdef int _do_constituent(const Transition* self, State* state) except -1: - return False - #cdef Constituent* bracket = new_bracket(state.ctnts) - - #bracket.parent = NULL - #bracket.label = self.label - #bracket.head = get_s0(state) - #bracket.length = 0 - - #attach(bracket, state.ctnts.stack) - # Attach rightward children. They're in the brackets array somewhere - # between here and B0. - #cdef Constituent* node - #cdef const TokenC* node_gov - #for i in range(1, bracket - state.ctnts.stack): - # node = bracket - i - # node_gov = node.head + node.head.head - # if node_gov == bracket.head: - # attach(bracket, node) - - -cdef int _do_adjust(const Transition* self, State* state) except -1: - return False - #cdef Constituent* b0 = state.ctnts.stack[0] - #cdef Constituent* b1 = state.ctnts.stack[1] - - #assert (b1.head + b1.head.head) == b0.head - #assert b0.head < b1.head - #assert b0 < b1 - - #attach(b0, b1) - ## Pop B1 from stack, but keep B0 on top - #state.ctnts.stack -= 1 - #state.ctnts.stack[0] = b0 - - -do_funcs[SHIFT] = _do_shift -do_funcs[REDUCE] = _do_reduce -do_funcs[LEFT] = _do_left -do_funcs[RIGHT] = _do_right -do_funcs[BREAK] = _do_break -do_funcs[CONSTITUENT] = _do_constituent -do_funcs[ADJUST] = _do_adjust - - cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1: if not _can_shift(s): return 9000 @@ -257,7 +225,6 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc cost += 1 return cost - cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1: if not _can_right(s): return 9000 @@ -322,6 +289,77 @@ cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) exc return cost +cdef inline bint _can_shift(const State* s) nogil: + return not at_eol(s) + + +cdef inline bint _can_right(const State* s) nogil: + return s.stack_len >= 1 and not at_eol(s) + + +cdef inline bint _can_left(const State* s) nogil: + if NON_MONOTONIC: + return s.stack_len >= 1 #and not missing_brackets(s) + else: + return s.stack_len >= 1 and not has_head(get_s0(s)) + + +cdef inline bint _can_reduce(const State* s) nogil: + if NON_MONOTONIC: + return s.stack_len >= 2 #and not missing_brackets(s) + else: + return s.stack_len >= 2 and has_head(get_s0(s)) + +cdef inline bint _can_break(const State* s) nogil: + cdef int i + if not USE_BREAK: + return False + elif at_eol(s): + return False + #elif NON_MONOTONIC: + # return True + else: + # In the Break transition paper, they have this constraint that prevents + # Break if stack is disconnected. But, if we're doing non-monotonic parsing, + # we prefer to relax this constraint. This is helpful in parsing whole + # documents, because then we don't get stuck with words on the stack. + seen_headless = False + for i in range(s.stack_len): + if s.sent[s.stack[-i]].head == 0: + if seen_headless: + return False + else: + seen_headless = True + # TODO: Constituency constraints + return True + +cdef inline bint _can_constituent(const State* s) nogil: + if s.stack_len < 1: + return False + return False + #else: + # # If all stack elements are popped, can't constituent + # for i in range(s.ctnts.stack_len): + # if not s.ctnts.is_popped[-i]: + # return True + # else: + # return False + +cdef inline bint _can_adjust(const State* s) nogil: + return False + #if s.ctnts.stack_len < 2: + # return False + + #cdef const Constituent* b1 = s.ctnts.stack[-1] + #cdef const Constituent* b0 = s.ctnts.stack[0] + + #if (b1.head + b1.head.head) != b0.head: + # return False + #elif b0.head >= b1.head: + # return False + #elif b0 >= b1: + # return False + cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1: if not _can_constituent(s): return 9000 @@ -349,7 +387,6 @@ cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gol # else: # loss = 1 # If we see the start position, set loss to 1 #return loss - cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1: if not _can_adjust(s): @@ -383,85 +420,37 @@ cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) ex #return loss -get_cost_funcs[SHIFT] = _shift_cost -get_cost_funcs[REDUCE] = _reduce_cost -get_cost_funcs[LEFT] = _left_cost -get_cost_funcs[RIGHT] = _right_cost -get_cost_funcs[BREAK] = _break_cost -get_cost_funcs[CONSTITUENT] = _constituent_cost -get_cost_funcs[ADJUST] = _adjust_cost - - -cdef inline bint _can_shift(const State* s) nogil: - return not at_eol(s) - - -cdef inline bint _can_right(const State* s) nogil: - return s.stack_len >= 1 and not at_eol(s) - - -cdef inline bint _can_left(const State* s) nogil: - if NON_MONOTONIC: - return s.stack_len >= 1 #and not missing_brackets(s) - else: - return s.stack_len >= 1 and not has_head(get_s0(s)) - - -cdef inline bint _can_reduce(const State* s) nogil: - if NON_MONOTONIC: - return s.stack_len >= 2 #and not missing_brackets(s) - else: - return s.stack_len >= 2 and has_head(get_s0(s)) - - -cdef inline bint _can_break(const State* s) nogil: - cdef int i - if not USE_BREAK: - return False - elif at_eol(s): - return False - #elif NON_MONOTONIC: - # return True - else: - # In the Break transition paper, they have this constraint that prevents - # Break if stack is disconnected. But, if we're doing non-monotonic parsing, - # we prefer to relax this constraint. This is helpful in parsing whole - # documents, because then we don't get stuck with words on the stack. - seen_headless = False - for i in range(s.stack_len): - if s.sent[s.stack[-i]].head == 0: - if seen_headless: - return False - else: - seen_headless = True - # TODO: Constituency constraints - return True - - -cdef inline bint _can_constituent(const State* s) nogil: - if s.stack_len < 1: - return False +cdef int _do_constituent(const Transition* self, State* state) except -1: return False - #else: - # # If all stack elements are popped, can't constituent - # for i in range(s.ctnts.stack_len): - # if not s.ctnts.is_popped[-i]: - # return True - # else: - # return False + #cdef Constituent* bracket = new_bracket(state.ctnts) + + #bracket.parent = NULL + #bracket.label = self.label + #bracket.head = get_s0(state) + #bracket.length = 0 + + #attach(bracket, state.ctnts.stack) + # Attach rightward children. They're in the brackets array somewhere + # between here and B0. + #cdef Constituent* node + #cdef const TokenC* node_gov + #for i in range(1, bracket - state.ctnts.stack): + # node = bracket - i + # node_gov = node.head + node.head.head + # if node_gov == bracket.head: + # attach(bracket, node) -cdef inline bint _can_adjust(const State* s) nogil: +cdef int _do_adjust(const Transition* self, State* state) except -1: return False - #if s.ctnts.stack_len < 2: - # return False + #cdef Constituent* b0 = state.ctnts.stack[0] + #cdef Constituent* b1 = state.ctnts.stack[1] - #cdef const Constituent* b1 = s.ctnts.stack[-1] - #cdef const Constituent* b0 = s.ctnts.stack[0] + #assert (b1.head + b1.head.head) == b0.head + #assert b0.head < b1.head + #assert b0 < b1 - #if (b1.head + b1.head.head) != b0.head: - # return False - #elif b0.head >= b1.head: - # return False - #elif b0 >= b1: - # return False + #attach(b0, b1) + ## Pop B1 from stack, but keep B0 on top + #state.ctnts.stack -= 1 + #state.ctnts.stack[0] = b0 diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 426a715d7..917bab594 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -140,12 +140,11 @@ cdef class BiluoPushDown(TransitionSystem): t.score = score return t - cdef bint* get_valid(self, const State* s) except NULL: + cdef int set_valid(self, bint* output, const State* s) except -1: cdef int i for i in range(self.n_moves): m = &self.c[i] - self._is_valid[i] = _is_valid(m.move, m.label, s) - return self._is_valid + output[i] = _is_valid(m.move, m.label, s) cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1: diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 57f1943b2..0afab9f1a 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -40,7 +40,7 @@ cdef class TransitionSystem: cdef Transition init_transition(self, int clas, int move, int label) except * - cdef bint* get_valid(self, const State* state) except NULL + cdef int set_valid(self, bint* output, const State* state) except -1 cdef Transition best_valid(self, const weight_t* scores, const State* state) except * diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 67c33155c..a03620d3b 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -45,7 +45,7 @@ cdef class TransitionSystem: cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: raise NotImplementedError - cdef bint* get_valid(self, const State* state) except NULL: + cdef int set_valid(self, bint* output, const State* state) except -1: raise NotImplementedError cdef Transition best_gold(self, const weight_t* scores, const State* s,