mirror of https://github.com/explosion/spaCy.git
* Move StateClass into interface of transition functions
This commit is contained in:
parent
4b98b3e9c8
commit
d68c686ec1
|
@ -120,11 +120,11 @@ cdef class Shift:
|
|||
return not st.eol()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
cdef int transition(StateClass state, int label) except -1:
|
||||
# Set the dep label, in case we need it after we reduce
|
||||
if NON_MONOTONIC:
|
||||
state.sent[state.i].dep = label
|
||||
push_stack(state)
|
||||
state._sent[state.B(0)].dep = label
|
||||
state.push()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass st, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -148,10 +148,10 @@ cdef class Reduce:
|
|||
return st.stack_depth() >= 2 and st.has_head(st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
if NON_MONOTONIC and not has_head(get_s0(state)) and state.stack_len >= 2:
|
||||
add_dep(state, state.stack[-1], state.stack[0], get_s0(state).dep)
|
||||
pop_stack(state)
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
if NON_MONOTONIC and not st.has_head(st.S(0)) and st.stack_depth() >= 2:
|
||||
st.add_arc(st.S(1), st.S(0), st.S_(0).dep)
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -178,13 +178,13 @@ cdef class LeftArc:
|
|||
return st.stack_depth() >= 1 and not st.has_head(st.S(0))
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
# Interpret left-arcs from EOL as attachment to root
|
||||
if at_eol(state):
|
||||
add_dep(state, state.stack[0], state.stack[0], label)
|
||||
if st.eol():
|
||||
st.add_arc(st.S(0), st.S(0), label)
|
||||
else:
|
||||
add_dep(state, state.i, state.stack[0], label)
|
||||
pop_stack(state)
|
||||
st.add_arc(st.B(0), st.S(0), label)
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -208,9 +208,9 @@ cdef class RightArc:
|
|||
return st.stack_depth() >= 1 and not st.eol()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
add_dep(state, state.stack[0], state.i, label)
|
||||
push_stack(state)
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
st.add_arc(st.S(0), st.B(0), label)
|
||||
st.push()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -256,13 +256,12 @@ cdef class Break:
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* state, int label) except -1:
|
||||
state.sent[state.i-1].sent_end = True
|
||||
while state.stack_len != 0:
|
||||
if get_s0(state).head == 0:
|
||||
get_s0(state).dep = label
|
||||
state.stack -= 1
|
||||
state.stack_len -= 1
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
st.set_sent_end(st.B(0)-1)
|
||||
while not st.empty():
|
||||
if not st.has_head(st.S(0)):
|
||||
st._sent[st.S(0)].dep = label
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -370,11 +369,11 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef int initialize_state(self, State* state) except -1:
|
||||
push_stack(state)
|
||||
|
||||
cdef int finalize_state(self, State* state) except -1:
|
||||
cdef int finalize_state(self, StateClass st) except -1:
|
||||
cdef int root_label = self.strings['ROOT']
|
||||
for i in range(state.sent_len):
|
||||
if state.sent[i].head == 0 and state.sent[i].dep == 0:
|
||||
state.sent[i].dep = root_label
|
||||
for i in range(st.length):
|
||||
if st._sent[i].head == 0 and st._sent[i].dep == 0:
|
||||
st._sent[i].dep = root_label
|
||||
|
||||
cdef int set_valid(self, bint* output, StateClass stcls) except -1:
|
||||
cdef bint[N_MOVES] is_valid
|
||||
|
|
|
@ -158,7 +158,7 @@ cdef class Missing:
|
|||
return False
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
cdef int transition(StateClass s, int label) except -1:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
|
@ -172,15 +172,11 @@ cdef class Begin:
|
|||
return label != 0 and not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.ent += 1
|
||||
s.ents_len += 1
|
||||
s.ent.start = s.i
|
||||
s.ent.label = label
|
||||
s.ent.end = 0
|
||||
s.sent[s.i].ent_iob = 3
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
st.open_ent(label)
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -206,10 +202,10 @@ cdef class In:
|
|||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.sent[s.i].ent_iob = 1
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
st.set_ent_tag(st.B(0), 1, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -246,11 +242,10 @@ cdef class Last:
|
|||
return st.entity_is_open() and label != 0 and st.E_(0).ent_type == label
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.ent.end = s.i+1
|
||||
s.sent[s.i].ent_iob = 1
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
st.close_ent()
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -286,15 +281,12 @@ cdef class Unit:
|
|||
return label != 0 and not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.ent += 1
|
||||
s.ents_len += 1
|
||||
s.ent.start = s.i
|
||||
s.ent.label = label
|
||||
s.ent.end = s.i+1
|
||||
s.sent[s.i].ent_iob = 3
|
||||
s.sent[s.i].ent_type = label
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
st.open_ent(label)
|
||||
st.close_ent()
|
||||
st.set_ent_tag(st.B(0), 3, label)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
@ -320,9 +312,10 @@ cdef class Out:
|
|||
return not st.entity_is_open()
|
||||
|
||||
@staticmethod
|
||||
cdef int transition(State* s, int label) except -1:
|
||||
s.sent[s.i].ent_iob = 2
|
||||
s.i += 1
|
||||
cdef int transition(StateClass st, int label) except -1:
|
||||
st.set_ent_tag(st.B(0), 2, 0)
|
||||
st.push()
|
||||
st.pop()
|
||||
|
||||
@staticmethod
|
||||
cdef int cost(StateClass s, const GoldParseC* gold, int label) except -1:
|
||||
|
|
|
@ -106,15 +106,17 @@ cdef class Parser:
|
|||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass(state.sent_len)
|
||||
stcls.from_struct(state)
|
||||
cdef Transition guess
|
||||
while not is_final(state):
|
||||
stcls.from_struct(state)
|
||||
words = [w.orth_ for w in tokens]
|
||||
while not stcls.is_final():
|
||||
#print stcls.print_state(words)
|
||||
_new_fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
guess.do(state, guess.label)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
guess.do(stcls, guess.label)
|
||||
self.moves.finalize_state(stcls)
|
||||
tokens.set_parse(stcls._sent)
|
||||
|
||||
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||
cdef Beam beam = Beam(self.moves.n_moves, self.cfg.beam_width)
|
||||
|
@ -123,8 +125,9 @@ cdef class Parser:
|
|||
while not beam.is_done:
|
||||
self._advance_beam(beam, None, False)
|
||||
state = <State*>beam.at(0)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
#self.moves.finalize_state(state)
|
||||
#tokens.set_parse(state.sent)
|
||||
raise Exception
|
||||
|
||||
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
||||
cdef Pool mem = Pool()
|
||||
|
@ -137,17 +140,18 @@ cdef class Parser:
|
|||
cdef Transition guess
|
||||
cdef Transition best
|
||||
cdef StateClass stcls = StateClass(state.sent_len)
|
||||
stcls.from_struct(state)
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
loss = 0
|
||||
while not is_final(state):
|
||||
stcls.from_struct(state)
|
||||
words = [w.orth_ for w in tokens]
|
||||
while not stcls.is_final():
|
||||
_new_fill_context(context, stcls)
|
||||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, stcls)
|
||||
best = self.moves.best_gold(scores, stcls, gold)
|
||||
cost = guess.get_cost(stcls, &gold.c, guess.label)
|
||||
self.model.update(context, guess.clas, best.clas, cost)
|
||||
guess.do(state, guess.label)
|
||||
guess.do(stcls, guess.label)
|
||||
loss += cost
|
||||
return loss
|
||||
|
||||
|
@ -203,14 +207,16 @@ cdef class Parser:
|
|||
cdef Pool mem = Pool()
|
||||
cdef State* state = new_state(mem, tokens.data, tokens.length)
|
||||
self.moves.initialize_state(state)
|
||||
cdef StateClass stcls = StateClass(state.sent_len)
|
||||
stcls.from_struct(state)
|
||||
|
||||
cdef class_t clas
|
||||
cdef int n_feats
|
||||
for clas in hist:
|
||||
fill_context(context, state)
|
||||
_new_fill_context(context, stcls)
|
||||
feats = self.model._extractor.get_feats(context, &n_feats)
|
||||
count_feats(counts[clas], feats, n_feats, inc)
|
||||
self.moves.c[clas].do(state, self.moves.c[clas].label)
|
||||
self.moves.c[clas].do(stcls, self.moves.c[clas].label)
|
||||
|
||||
|
||||
# These are passed as callbacks to thinc.search.Beam
|
||||
|
@ -220,7 +226,8 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
|||
src = <const State*>_src
|
||||
moves = <const Transition*>_moves
|
||||
copy_state(dest, src)
|
||||
moves[clas].do(dest, moves[clas].label)
|
||||
raise Exception
|
||||
#moves[clas].do(dest, moves[clas].label)
|
||||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
|
|
|
@ -126,7 +126,7 @@ cdef class StateClass:
|
|||
return self._b_i >= self.length
|
||||
|
||||
cdef bint is_final(self) nogil:
|
||||
return self.eol() and self.empty()
|
||||
return self.eol() and self.stack_depth() <= 1
|
||||
|
||||
cdef bint has_head(self, int i) nogil:
|
||||
return self.safe_get(i).head != 0
|
||||
|
@ -196,7 +196,7 @@ cdef class StateClass:
|
|||
self._sent[i].ent_type = ent_type
|
||||
|
||||
cdef void set_sent_end(self, int i) nogil:
|
||||
if 0 < i < self.length:
|
||||
if 0 <= i < self.length:
|
||||
self._sent[i].sent_end = True
|
||||
|
||||
cdef void clone(self, StateClass src) nogil:
|
||||
|
@ -207,6 +207,17 @@ cdef class StateClass:
|
|||
self._b_i = src._b_i
|
||||
self._s_i = src._s_i
|
||||
self._e_i = src._e_i
|
||||
|
||||
def print_state(self, words):
|
||||
words = list(words) + ['_']
|
||||
top = words[self.S(0)] + '_%d' % self.H(self.S(0))
|
||||
second = words[self.S(1)] + '_%d' % self.H(self.S(1))
|
||||
third = words[self.S(2)] + '_%d' % self.H(self.S(2))
|
||||
n0 = words[self.B(0)]
|
||||
n1 = words[self.B(1)]
|
||||
return ' '.join((str(self.stack_depth()), third, second, top, '|', n0, n1))
|
||||
|
||||
|
||||
|
||||
|
||||
# From https://en.wikipedia.org/wiki/Hamming_weight
|
||||
|
|
|
@ -19,14 +19,14 @@ cdef struct Transition:
|
|||
|
||||
bint (*is_valid)(StateClass state, int label) except -1
|
||||
int (*get_cost)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
int (*do)(State* state, int label) except -1
|
||||
int (*do)(StateClass state, int label) except -1
|
||||
|
||||
|
||||
ctypedef int (*get_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
ctypedef int (*move_cost_func_t)(StateClass state, const GoldParseC* gold) except -1
|
||||
ctypedef int (*label_cost_func_t)(StateClass state, const GoldParseC* gold, int label) except -1
|
||||
|
||||
ctypedef int (*do_func_t)(State* state, int label) except -1
|
||||
ctypedef int (*do_func_t)(StateClass state, int label) except -1
|
||||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
|
@ -37,7 +37,7 @@ cdef class TransitionSystem:
|
|||
cdef readonly int n_moves
|
||||
|
||||
cdef int initialize_state(self, State* state) except -1
|
||||
cdef int finalize_state(self, State* state) except -1
|
||||
cdef int finalize_state(self, StateClass state) except -1
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ cdef class TransitionSystem:
|
|||
cdef int initialize_state(self, State* state) except -1:
|
||||
pass
|
||||
|
||||
cdef int finalize_state(self, State* state) except -1:
|
||||
cdef int finalize_state(self, StateClass state) except -1:
|
||||
pass
|
||||
|
||||
cdef int preprocess_gold(self, GoldParse gold) except -1:
|
||||
|
|
Loading…
Reference in New Issue