diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 1a4b4829f..d60161a98 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -63,53 +63,53 @@ cdef cppclass StateC: free(this._stack - PADDING) free(this.shifted - PADDING) - int S(int i) nogil: + int S(int i) nogil const: if i >= this._s_i: return -1 return this._stack[this._s_i - (i+1)] - int B(int i) nogil: + int B(int i) nogil const: if (i + this._b_i) >= this.length: return -1 return this._buffer[this._b_i + i] - const TokenC* S_(int i) nogil: + const TokenC* S_(int i) nogil const: return this.safe_get(this.S(i)) - const TokenC* B_(int i) nogil: + const TokenC* B_(int i) nogil const: return this.safe_get(this.B(i)) - const TokenC* H_(int i) nogil: + const TokenC* H_(int i) nogil const: return this.safe_get(this.H(i)) - const TokenC* E_(int i) nogil: + const TokenC* E_(int i) nogil const: return this.safe_get(this.E(i)) - const TokenC* L_(int i, int idx) nogil: + const TokenC* L_(int i, int idx) nogil const: return this.safe_get(this.L(i, idx)) - const TokenC* R_(int i, int idx) nogil: + const TokenC* R_(int i, int idx) nogil const: return this.safe_get(this.R(i, idx)) - const TokenC* safe_get(int i) nogil: + const TokenC* safe_get(int i) nogil const: if i < 0 or i >= this.length: return &this._empty_token else: return &this._sent[i] - int H(int i) nogil: + int H(int i) nogil const: if i < 0 or i >= this.length: return -1 return this._sent[i].head + i - int E(int i) nogil: + int E(int i) nogil const: if this._e_i <= 0 or this._e_i >= this.length: return 0 if i < 0 or i >= this._e_i: return 0 return this._ents[this._e_i - (i+1)].start - int L(int i, int idx) nogil: + int L(int i, int idx) nogil const: if idx < 1: return -1 if i < 0 or i >= this.length: @@ -135,7 +135,7 @@ cdef cppclass StateC: ptr += 1 return -1 - int R(int i, int idx) nogil: + int R(int i, int idx) nogil const: if idx < 1: return -1 if i < 0 or i >= this.length: @@ -159,39 +159,39 @@ cdef cppclass StateC: ptr -= 1 return -1 - bint empty() nogil: + bint empty() nogil const: return this._s_i <= 0 - bint eol() nogil: + bint eol() nogil const: return this.buffer_length() == 0 - bint at_break() nogil: + bint at_break() nogil const: return this._break != -1 - bint is_final() nogil: + bint is_final() nogil const: return this.stack_depth() <= 0 and this._b_i >= this.length - bint has_head(int i) nogil: + bint has_head(int i) nogil const: return this.safe_get(i).head != 0 - int n_L(int i) nogil: + int n_L(int i) nogil const: return this.safe_get(i).l_kids - int n_R(int i) nogil: + int n_R(int i) nogil const: return this.safe_get(i).r_kids - bint stack_is_connected() nogil: + bint stack_is_connected() nogil const: return False - bint entity_is_open() nogil: + bint entity_is_open() nogil const: if this._e_i < 1: return False return this._ents[this._e_i-1].end == -1 - int stack_depth() nogil: + int stack_depth() nogil const: return this._s_i - int buffer_length() nogil: + int buffer_length() nogil const: if this._break != -1: return this._break - this._b_i else: diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 403aa7d3d..b9cec50ad 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -17,6 +17,7 @@ from libc.string cimport memcpy from cymem.cymem cimport Pool from .stateclass cimport StateClass +from ._state cimport StateC DEF NON_MONOTONIC = True @@ -57,7 +58,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no cost += 1 if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)): cost += 1 - cost += Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0 + cost += Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0 return cost @@ -70,7 +71,7 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog cost += gold.heads[target] == B_i if gold.heads[B_i] == B_i or gold.heads[B_i] < target: break - if Break.is_valid(stcls, -1) and Break.move_cost(stcls, gold) == 0: + if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0: cost += 1 return cost @@ -115,11 +116,11 @@ cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil: cdef class Shift: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: - return st.buffer_length() >= 2 and not st.c.shifted[st.B(0)] and not st.B_(0).sent_start + cdef bint is_valid(const StateC* st, int label) nogil: + return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and not st.B_(0).sent_start @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.push() st.fast_forward() @@ -138,11 +139,11 @@ cdef class Shift: cdef class Reduce: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: return st.stack_depth() >= 2 @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: if st.has_head(st.S(0)): st.pop() else: @@ -164,11 +165,11 @@ cdef class Reduce: cdef class LeftArc: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: return not st.B_(0).sent_start @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.add_arc(st.B(0), st.S(0), label) st.pop() st.fast_forward() @@ -197,11 +198,11 @@ cdef class LeftArc: cdef class RightArc: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: return not st.B_(0).sent_start @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.add_arc(st.S(0), st.B(0), label) st.push() st.fast_forward() @@ -226,7 +227,7 @@ cdef class RightArc: cdef class Break: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: cdef int i if not USE_BREAK: return False @@ -243,7 +244,7 @@ cdef class Break: return True @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.set_break(st.B(0)) st.fast_forward() @@ -396,11 +397,11 @@ cdef class ArcEager(TransitionSystem): cdef int set_valid(self, int* output, StateClass stcls) nogil: cdef bint[N_MOVES] is_valid - is_valid[SHIFT] = Shift.is_valid(stcls, -1) - is_valid[REDUCE] = Reduce.is_valid(stcls, -1) - is_valid[LEFT] = LeftArc.is_valid(stcls, -1) - is_valid[RIGHT] = RightArc.is_valid(stcls, -1) - is_valid[BREAK] = Break.is_valid(stcls, -1) + is_valid[SHIFT] = Shift.is_valid(stcls.c, -1) + is_valid[REDUCE] = Reduce.is_valid(stcls.c, -1) + is_valid[LEFT] = LeftArc.is_valid(stcls.c, -1) + is_valid[RIGHT] = RightArc.is_valid(stcls.c, -1) + is_valid[BREAK] = Break.is_valid(stcls.c, -1) cdef int i for i in range(self.n_moves): output[i] = is_valid[self.c[i].move] @@ -430,7 +431,7 @@ cdef class ArcEager(TransitionSystem): n_gold = 0 for i in range(self.n_moves): - if self.c[i].is_valid(stcls, self.c[i].label): + if self.c[i].is_valid(stcls.c, self.c[i].label): is_valid[i] = True move = self.c[i].move label = self.c[i].label diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index da2887318..62d4dc0e2 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -11,6 +11,7 @@ from ..gold cimport GoldParse from ..attrs cimport ENT_TYPE, ENT_IOB from .stateclass cimport StateClass +from ._state cimport StateC cdef enum: @@ -150,11 +151,11 @@ cdef class BiluoPushDown(TransitionSystem): cdef class Missing: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: return False @staticmethod - cdef int transition(StateClass s, int label) nogil: + cdef int transition(StateC* s, int label) nogil: pass @staticmethod @@ -164,7 +165,7 @@ cdef class Missing: cdef class Begin: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: # Ensure we don't clobber preset entities. If no entity preset, # ent_iob is 0 cdef int preset_ent_iob = st.B_(0).ent_iob @@ -188,7 +189,7 @@ cdef class Begin: return label != 0 and not st.entity_is_open() @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.open_ent(label) st.set_ent_tag(st.B(0), 3, label) st.push() @@ -214,7 +215,7 @@ cdef class Begin: cdef class In: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob if preset_ent_iob == 2: return False @@ -230,7 +231,7 @@ cdef class In: return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.set_ent_tag(st.B(0), 1, label) st.push() st.pop() @@ -266,13 +267,13 @@ cdef class In: cdef class Last: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: if st.B_(1).ent_iob == 1: return False return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.close_ent() st.set_ent_tag(st.B(0), 1, label) st.push() @@ -308,7 +309,7 @@ cdef class Last: cdef class Unit: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob if preset_ent_iob == 2: return False @@ -321,7 +322,7 @@ cdef class Unit: return label != 0 and not st.entity_is_open() @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.open_ent(label) st.close_ent() st.set_ent_tag(st.B(0), 3, label) @@ -348,7 +349,7 @@ cdef class Unit: cdef class Out: @staticmethod - cdef bint is_valid(StateClass st, int label) nogil: + cdef bint is_valid(const StateC* st, int label) nogil: cdef int preset_ent_iob = st.B_(0).ent_iob if preset_ent_iob == 3: return False @@ -357,7 +358,7 @@ cdef class Out: return not st.entity_is_open() @staticmethod - cdef int transition(StateClass st, int label) nogil: + cdef int transition(StateC* st, int label) nogil: st.set_ent_tag(st.B(0), 2, 0) st.push() st.pop() diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 3be7edf1a..fb9b001d5 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -124,7 +124,7 @@ cdef class Parser: with gil: move_name = self.moves.move_name(action.move, action.label) raise ValueError("Illegal action: %s" % move_name) - action.do(stcls, action.label) + action.do(stcls.c, action.label) memset(eg.c.scores, 0, sizeof(eg.c.scores[0]) * eg.c.nr_class) memset(eg.c.costs, 0, sizeof(eg.c.costs[0]) * eg.c.nr_class) for i in range(eg.c.nr_class): @@ -151,7 +151,7 @@ cdef class Parser: guess = VecVec.arg_max_if_true(eg.c.scores, eg.c.is_valid, eg.c.nr_class) action = self.moves.c[eg.guess] - action.do(stcls, action.label) + action.do(stcls.c, action.label) loss += eg.costs[eg.guess] eg.reset_classes(eg.nr_class) return loss @@ -230,7 +230,7 @@ cdef class StepwiseState: action = self.parser.moves.c[clas] else: action = self.parser.moves.lookup_transition(action_name) - action.do(self.stcls, action.label) + action.do(self.stcls.c, action.label) def finish(self): if self.stcls.is_final(): diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 673582bb8..23d3561c4 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -7,6 +7,7 @@ from ..gold cimport GoldParseC from ..strings cimport StringStore from .stateclass cimport StateClass +from ._state cimport StateC cdef struct Transition: @@ -16,16 +17,16 @@ cdef struct Transition: weight_t score - bint (*is_valid)(StateClass state, int label) nogil + bint (*is_valid)(const StateC* state, int label) nogil weight_t (*get_cost)(StateClass state, const GoldParseC* gold, int label) nogil - int (*do)(StateClass state, int label) nogil + int (*do)(StateC* state, int label) nogil ctypedef weight_t (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil ctypedef weight_t (*move_cost_func_t)(StateClass state, const GoldParseC* gold) nogil ctypedef weight_t (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) nogil -ctypedef int (*do_func_t)(StateClass state, int label) nogil +ctypedef int (*do_func_t)(StateC* state, int label) nogil cdef class TransitionSystem: diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 775307aee..0b2a03202 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -64,12 +64,12 @@ cdef class TransitionSystem: def is_valid(self, StateClass stcls, move_name): action = self.lookup_transition(move_name) - return action.is_valid(stcls, action.label) + return action.is_valid(stcls.c, action.label) cdef int set_valid(self, int* is_valid, StateClass stcls) nogil: cdef int i for i in range(self.n_moves): - is_valid[i] = self.c[i].is_valid(stcls, self.c[i].label) + is_valid[i] = self.c[i].is_valid(stcls.c, self.c[i].label) cdef int set_costs(self, int* is_valid, weight_t* costs, StateClass stcls, GoldParse gold) except -1: