diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index a99c383f5..546ea5281 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -120,11 +120,11 @@ cdef class Shift: return not st.eol() @staticmethod - cdef int transition(State* state, int label) except -1: + cdef int transition(StateClass state, int label) except -1: # Set the dep label, in case we need it after we reduce if NON_MONOTONIC: - state.sent[state.i].dep = label - push_stack(state) + state._sent[state.B(0)].dep = label + state.push() @staticmethod cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1: @@ -148,10 +148,10 @@ cdef class Reduce: return st.stack_depth() >= 2 and st.has_head(st.S(0)) @staticmethod - cdef int transition(State* state, int label) except -1: - if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2: - add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep) - pop_stack(state) + cdef int transition(StateClass st, int label) except -1: + if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2: + st.add_arc(st.S(1), st.S(0), st.S_(0).dep) + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -178,13 +178,13 @@ cdef class LeftArc: return st.stack_depth() >= 1 and not st.has_head(st.S(0)) @staticmethod - cdef int transition(State* state, int label) except -1: + cdef int transition(StateClass st, int label) except -1: # Interpret left-arcs from EOL as attachment to root - if at_eol(state): - add_dep(state, state.stack[0], state.stack[0], label) + if st.eol(): + st.add_arc(st.S(0), st.S(0), label) else: - add_dep(state, state.i, state.stack[0], label) - pop_stack(state) + st.add_arc(st.B(0), st.S(0), label) + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -208,9 +208,9 @@ cdef class RightArc: return st.stack_depth() >= 1 and not st.eol() @staticmethod - cdef int transition(State* state, int label) except -1: - add_dep(state, state.stack[0], state.i, label) - push_stack(state) + cdef int transition(StateClass st, int label) except -1: + st.add_arc(st.S(0), st.B(0), label) + st.push() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -256,13 +256,12 @@ cdef class Break: return True @staticmethod - cdef int transition(State* state, int label) except -1: - state.sent[state.i-1].sent_end = True - while state.stack_len != 0: - if get_s0(state).head == 0: - get_s0(state).dep = label - state.stack -= 1 - state.stack_len -= 1 + cdef int transition(StateClass st, int label) except -1: + st.set_sent_end(st.B(0)-1) + while not st.empty(): + if not st.has_head(st.S(0)): + st._sent[st.S(0)].dep = label + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -370,11 +369,11 @@ cdef class ArcEager(TransitionSystem): cdef int initialize_state(self, State* state) except -1: push_stack(state) - cdef int finalize_state(self, State* state) except -1: + cdef int finalize_state(self, StateClass st) except -1: cdef int root_label = self.strings['ROOT'] - for i in range(state.sent_len): - if state.sent[i].head == 0 and state.sent[i].dep == 0: - state.sent[i].dep = root_label + for i in range(st.length): + if st._sent[i].head == 0 and st._sent[i].dep == 0: + st._sent[i].dep = root_label cdef int set_valid(self, bint* output, StateClass stcls) except -1: cdef bint[N_MOVES] is_valid diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 01aec7769..833d1f299 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -158,7 +158,7 @@ cdef class Missing: return False @staticmethod - cdef int transition(State* s, int label) except -1: + cdef int transition(StateClass s, int label) except -1: raise NotImplementedError @staticmethod @@ -172,15 +172,11 @@ cdef class Begin: return label != 0 and not st.entity_is_open() @staticmethod - cdef int transition(State* s, int label) except -1: - s.ent += 1 - s.ents_len += 1 - s.ent.start = s.i - s.ent.label = label - s.ent.end = 0 - s.sent[s.i].ent_iob = 3 - s.sent[s.i].ent_type = label - s.i += 1 + cdef int transition(StateClass st, int label) except -1: + st.open_ent(label) + st.set_ent_tag(st.B(0), 3, label) + st.push() + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -206,10 +202,10 @@ cdef class In: return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label @staticmethod - cdef int transition(State* s, int label) except -1: - s.sent[s.i].ent_iob = 1 - s.sent[s.i].ent_type = label - s.i += 1 + cdef int transition(StateClass st, int label) except -1: + st.set_ent_tag(st.B(0), 1, label) + st.push() + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -246,11 +242,10 @@ cdef class Last: return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label @staticmethod - cdef int transition(State* s, int label) except -1: - s.ent.end = s.i+1 - s.sent[s.i].ent_iob = 1 - s.sent[s.i].ent_type = label - s.i += 1 + cdef int transition(StateClass st, int label) except -1: + st.close_ent() + st.push() + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -286,15 +281,12 @@ cdef class Unit: return label != 0 and not st.entity_is_open() @staticmethod - cdef int transition(State* s, int label) except -1: - s.ent += 1 - s.ents_len += 1 - s.ent.start = s.i - s.ent.label = label - s.ent.end = s.i+1 - s.sent[s.i].ent_iob = 3 - s.sent[s.i].ent_type = label - s.i += 1 + cdef int transition(StateClass st, int label) except -1: + st.open_ent(label) + st.close_ent() + st.set_ent_tag(st.B(0), 3, label) + st.push() + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: @@ -320,9 +312,10 @@ cdef class Out: return not st.entity_is_open() @staticmethod - cdef int transition(State* s, int label) except -1: - s.sent[s.i].ent_iob = 2 - s.i += 1 + cdef int transition(StateClass st, int label) except -1: + st.set_ent_tag(st.B(0), 2, 0) + st.push() + st.pop() @staticmethod cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1: diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx index 2d4d2c3dc..1ff5a523f 100644 --- a/spacy/syntax/parser.pyx +++ b/spacy/syntax/parser.pyx @@ -106,15 +106,17 @@ cdef class Parser: cdef State* state = new_state(mem, tokens.data, tokens.length) self.moves.initialize_state(state) cdef StateClass stcls = StateClass(state.sent_len) + stcls.from_struct(state) cdef Transition guess - while not is_final(state): - stcls.from_struct(state) + words = [w.orth_ for w in tokens] + while not stcls.is_final(): + #print stcls.print_state(words) _new_fill_context(context, stcls) scores = self.model.score(context) guess = self.moves.best_valid(scores, stcls) - guess.do(state, guess.label) - self.moves.finalize_state(state) - tokens.set_parse(state.sent) + guess.do(stcls, guess.label) + self.moves.finalize_state(stcls) + tokens.set_parse(stcls._sent) cdef int _beam_parse(self, Tokens tokens) except -1: cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width) @@ -123,8 +125,9 @@ cdef class Parser: while not beam.is_done: self._advance_beam(beam, None, False) state = beam.at(0) - self.moves.finalize_state(state) - tokens.set_parse(state.sent) + #self.moves.finalize_state(state) + #tokens.set_parse(state.sent) + raise Exception def _greedy_train(self, Tokens tokens, GoldParse gold): cdef Pool mem = Pool() @@ -137,17 +140,18 @@ cdef class Parser: cdef Transition guess cdef Transition best cdef StateClass stcls = StateClass(state.sent_len) + stcls.from_struct(state) cdef atom_t[CONTEXT_SIZE] context loss = 0 - while not is_final(state): - stcls.from_struct(state) + words = [w.orth_ for w in tokens] + while not stcls.is_final(): _new_fill_context(context, stcls) scores = self.model.score(context) guess = self.moves.best_valid(scores, stcls) best = self.moves.best_gold(scores, stcls, gold) cost = guess.get_cost(stcls, &gold.c, guess.label) self.model.update(context, guess.clas, best.clas, cost) - guess.do(state, guess.label) + guess.do(stcls, guess.label) loss += cost return loss @@ -203,14 +207,16 @@ cdef class Parser: cdef Pool mem = Pool() cdef State* state = new_state(mem, tokens.data, tokens.length) self.moves.initialize_state(state) + cdef StateClass stcls = StateClass(state.sent_len) + stcls.from_struct(state) cdef class_t clas cdef int n_feats for clas in hist: - fill_context(context, state) + _new_fill_context(context, stcls) feats = self.model._extractor.get_feats(context, &n_feats) count_feats(counts[clas], feats, n_feats, inc) - self.moves.c[clas].do(state, self.moves.c[clas].label) + self.moves.c[clas].do(stcls, self.moves.c[clas].label) # These are passed as callbacks to thinc.search.Beam @@ -220,7 +226,8 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) src = _src moves = _moves copy_state(dest, src) - moves[clas].do(dest, moves[clas].label) + raise Exception + #moves[clas].do(dest, moves[clas].label) cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index d15a2b650..81227db26 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -126,7 +126,7 @@ cdef class StateClass: return self._b_i >= self.length cdef bint is_final(self) nogil: - return self.eol() and self.empty() + return self.eol() and self.stack_depth() <= 1 cdef bint has_head(self, int i) nogil: return self.safe_get(i).head != 0 @@ -196,7 +196,7 @@ cdef class StateClass: self._sent[i].ent_type = ent_type cdef void set_sent_end(self, int i) nogil: - if 0 < i < self.length: + if 0 <= i < self.length: self._sent[i].sent_end = True cdef void clone(self, StateClass src) nogil: @@ -207,6 +207,17 @@ cdef class StateClass: self._b_i = src._b_i self._s_i = src._s_i self._e_i = src._e_i + + def print_state(self, words): + words = list(words) + ['_'] + top = words[self.S(0)] + '_%d' % self.H(self.S(0)) + second = words[self.S(1)] + '_%d' % self.H(self.S(1)) + third = words[self.S(2)] + '_%d' % self.H(self.S(2)) + n0 = words[self.B(0)] + n1 = words[self.B(1)] + return ' '.join((str(self.stack_depth()), third, second, top, '|', n0, n1)) + + # From https://en.wikipedia.org/wiki/Hamming_weight diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 5027e66be..f144d282e 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -19,14 +19,14 @@ cdef struct Transition: bint (*is_valid)(StateClass state, int label) except -1 int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1 - int (*do)(State* state, int label) except -1 + int (*do)(StateClass state, int label) except -1 ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1 ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1 ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1 -ctypedef int (*do_func_t)(State* state, int label) except -1 +ctypedef int (*do_func_t)(StateClass state, int label) except -1 cdef class TransitionSystem: @@ -37,7 +37,7 @@ cdef class TransitionSystem: cdef readonly int n_moves cdef int initialize_state(self, State* state) except -1 - cdef int finalize_state(self, State* state) except -1 + cdef int finalize_state(self, StateClass state) except -1 cdef int preprocess_gold(self, GoldParse gold) except -1 diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index f1dd06320..6d972bcf9 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -32,7 +32,7 @@ cdef class TransitionSystem: cdef int initialize_state(self, State* state) except -1: pass - cdef int finalize_state(self, State* state) except -1: + cdef int finalize_state(self, StateClass state) except -1: pass cdef int preprocess_gold(self, GoldParse gold) except -1: