* Move StateClass into interface of transition functions

This commit is contained in:
Matthew Honnibal 2015-06-10 01:35:28 +02:00
parent 4b98b3e9c8
commit d68c686ec1
6 changed files with 86 additions and 76 deletions

View File

@ -120,11 +120,11 @@ cdef class Shift:
return not st.eol()
@staticmethod
cdef int transition(State* state, int label) except -1:
cdef int transition(StateClass state, int label) except -1:
# Set the dep label, in case we need it after we reduce
if NON_MONOTONIC:
state.sent[state.i].dep = label
push_stack(state)
state._sent[state.B(0)].dep = label
state.push()
@staticmethod
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
@ -148,10 +148,10 @@ cdef class Reduce:
return st.stack_depth() >= 2 and st.has_head(st.S(0))
@staticmethod
cdef int transition(State* state, int label) except -1:
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
pop_stack(state)
cdef int transition(StateClass st, int label) except -1:
if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -178,13 +178,13 @@ cdef class LeftArc:
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
@staticmethod
cdef int transition(State* state, int label) except -1:
cdef int transition(StateClass st, int label) except -1:
# Interpret left-arcs from EOL as attachment to root
if at_eol(state):
add_dep(state, state.stack[0], state.stack[0], label)
if st.eol():
st.add_arc(st.S(0), st.S(0), label)
else:
add_dep(state, state.i, state.stack[0], label)
pop_stack(state)
st.add_arc(st.B(0), st.S(0), label)
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -208,9 +208,9 @@ cdef class RightArc:
return st.stack_depth() >= 1 and not st.eol()
@staticmethod
cdef int transition(State* state, int label) except -1:
add_dep(state, state.stack[0], state.i, label)
push_stack(state)
cdef int transition(StateClass st, int label) except -1:
st.add_arc(st.S(0), st.B(0), label)
st.push()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -256,13 +256,12 @@ cdef class Break:
return True
@staticmethod
cdef int transition(State* state, int label) except -1:
state.sent[state.i-1].sent_end = True
while state.stack_len != 0:
if get_s0(state).head == 0:
get_s0(state).dep = label
state.stack -= 1
state.stack_len -= 1
cdef int transition(StateClass st, int label) except -1:
st.set_sent_end(st.B(0)-1)
while not st.empty():
if not st.has_head(st.S(0)):
st._sent[st.S(0)].dep = label
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -370,11 +369,11 @@ cdef class ArcEager(TransitionSystem):
cdef int initialize_state(self, State* state) except -1:
push_stack(state)
cdef int finalize_state(self, State* state) except -1:
cdef int finalize_state(self, StateClass st) except -1:
cdef int root_label = self.strings['ROOT']
for i in range(state.sent_len):
if state.sent[i].head == 0 and state.sent[i].dep == 0:
state.sent[i].dep = root_label
for i in range(st.length):
if st._sent[i].head == 0 and st._sent[i].dep == 0:
st._sent[i].dep = root_label
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
cdef bint[N_MOVES] is_valid

View File

@ -158,7 +158,7 @@ cdef class Missing:
return False
@staticmethod
cdef int transition(State* s, int label) except -1:
cdef int transition(StateClass s, int label) except -1:
raise NotImplementedError
@staticmethod
@ -172,15 +172,11 @@ cdef class Begin:
return label != 0 and not st.entity_is_open()
@staticmethod
cdef int transition(State* s, int label) except -1:
s.ent += 1
s.ents_len += 1
s.ent.start = s.i
s.ent.label = label
s.ent.end = 0
s.sent[s.i].ent_iob = 3
s.sent[s.i].ent_type = label
s.i += 1
cdef int transition(StateClass st, int label) except -1:
st.open_ent(label)
st.set_ent_tag(st.B(0), 3, label)
st.push()
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -206,10 +202,10 @@ cdef class In:
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod
cdef int transition(State* s, int label) except -1:
s.sent[s.i].ent_iob = 1
s.sent[s.i].ent_type = label
s.i += 1
cdef int transition(StateClass st, int label) except -1:
st.set_ent_tag(st.B(0), 1, label)
st.push()
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -246,11 +242,10 @@ cdef class Last:
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
@staticmethod
cdef int transition(State* s, int label) except -1:
s.ent.end = s.i+1
s.sent[s.i].ent_iob = 1
s.sent[s.i].ent_type = label
s.i += 1
cdef int transition(StateClass st, int label) except -1:
st.close_ent()
st.push()
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -286,15 +281,12 @@ cdef class Unit:
return label != 0 and not st.entity_is_open()
@staticmethod
cdef int transition(State* s, int label) except -1:
s.ent += 1
s.ents_len += 1
s.ent.start = s.i
s.ent.label = label
s.ent.end = s.i+1
s.sent[s.i].ent_iob = 3
s.sent[s.i].ent_type = label
s.i += 1
cdef int transition(StateClass st, int label) except -1:
st.open_ent(label)
st.close_ent()
st.set_ent_tag(st.B(0), 3, label)
st.push()
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
@ -320,9 +312,10 @@ cdef class Out:
return not st.entity_is_open()
@staticmethod
cdef int transition(State* s, int label) except -1:
s.sent[s.i].ent_iob = 2
s.i += 1
cdef int transition(StateClass st, int label) except -1:
st.set_ent_tag(st.B(0), 2, 0)
st.push()
st.pop()
@staticmethod
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:

View File

@ -106,15 +106,17 @@ cdef class Parser:
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef Transition guess
while not is_final(state):
stcls.from_struct(state)
words = [w.orth_ for w in tokens]
while not stcls.is_final():
#print stcls.print_state(words)
_new_fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls)
guess.do(state, guess.label)
self.moves.finalize_state(state)
tokens.set_parse(state.sent)
guess.do(stcls, guess.label)
self.moves.finalize_state(stcls)
tokens.set_parse(stcls._sent)
cdef int _beam_parse(self, Tokens tokens) except -1:
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
@ -123,8 +125,9 @@ cdef class Parser:
while not beam.is_done:
self._advance_beam(beam, None, False)
state = <State*>beam.at(0)
self.moves.finalize_state(state)
tokens.set_parse(state.sent)
#self.moves.finalize_state(state)
#tokens.set_parse(state.sent)
raise Exception
def _greedy_train(self, Tokens tokens, GoldParse gold):
cdef Pool mem = Pool()
@ -137,17 +140,18 @@ cdef class Parser:
cdef Transition guess
cdef Transition best
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef atom_t[CONTEXT_SIZE] context
loss = 0
while not is_final(state):
stcls.from_struct(state)
words = [w.orth_ for w in tokens]
while not stcls.is_final():
_new_fill_context(context, stcls)
scores = self.model.score(context)
guess = self.moves.best_valid(scores, stcls)
best = self.moves.best_gold(scores, stcls, gold)
cost = guess.get_cost(stcls, &gold.c, guess.label)
self.model.update(context, guess.clas, best.clas, cost)
guess.do(state, guess.label)
guess.do(stcls, guess.label)
loss += cost
return loss
@ -203,14 +207,16 @@ cdef class Parser:
cdef Pool mem = Pool()
cdef State* state = new_state(mem, tokens.data, tokens.length)
self.moves.initialize_state(state)
cdef StateClass stcls = StateClass(state.sent_len)
stcls.from_struct(state)
cdef class_t clas
cdef int n_feats
for clas in hist:
fill_context(context, state)
_new_fill_context(context, stcls)
feats = self.model._extractor.get_feats(context, &n_feats)
count_feats(counts[clas], feats, n_feats, inc)
self.moves.c[clas].do(state, self.moves.c[clas].label)
self.moves.c[clas].do(stcls, self.moves.c[clas].label)
# These are passed as callbacks to thinc.search.Beam
@ -220,7 +226,8 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
src = <const State*>_src
moves = <const Transition*>_moves
copy_state(dest, src)
moves[clas].do(dest, moves[clas].label)
raise Exception
#moves[clas].do(dest, moves[clas].label)
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:

View File

@ -126,7 +126,7 @@ cdef class StateClass:
return self._b_i >= self.length
cdef bint is_final(self) nogil:
return self.eol() and self.empty()
return self.eol() and self.stack_depth() <= 1
cdef bint has_head(self, int i) nogil:
return self.safe_get(i).head != 0
@ -196,7 +196,7 @@ cdef class StateClass:
self._sent[i].ent_type = ent_type
cdef void set_sent_end(self, int i) nogil:
if 0 < i < self.length:
if 0 <= i < self.length:
self._sent[i].sent_end = True
cdef void clone(self, StateClass src) nogil:
@ -208,6 +208,17 @@ cdef class StateClass:
self._s_i = src._s_i
self._e_i = src._e_i
def print_state(self, words):
words = list(words) + ['_']
top = words[self.S(0)] + '_%d' % self.H(self.S(0))
second = words[self.S(1)] + '_%d' % self.H(self.S(1))
third = words[self.S(2)] + '_%d' % self.H(self.S(2))
n0 = words[self.B(0)]
n1 = words[self.B(1)]
return ' '.join((str(self.stack_depth()), third, second, top, '|', n0, n1))
# From https://en.wikipedia.org/wiki/Hamming_weight
cdef inline uint32_t _popcount(uint32_t x) nogil:

View File

@ -19,14 +19,14 @@ cdef struct Transition:
bint (*is_valid)(StateClass state, int label) except -1
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
int (*do)(State* state, int label) except -1
int (*do)(StateClass state, int label) except -1
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
ctypedef int (*do_func_t)(State* state, int label) except -1
ctypedef int (*do_func_t)(StateClass state, int label) except -1
cdef class TransitionSystem:
@ -37,7 +37,7 @@ cdef class TransitionSystem:
cdef readonly int n_moves
cdef int initialize_state(self, State* state) except -1
cdef int finalize_state(self, State* state) except -1
cdef int finalize_state(self, StateClass state) except -1
cdef int preprocess_gold(self, GoldParse gold) except -1

View File

@ -32,7 +32,7 @@ cdef class TransitionSystem:
cdef int initialize_state(self, State* state) except -1:
pass
cdef int finalize_state(self, State* state) except -1:
cdef int finalize_state(self, StateClass state) except -1:
pass
cdef int preprocess_gold(self, GoldParse gold) except -1: