* Add get_valid method

This commit is contained in:
Matthew Honnibal 2015-06-01 23:05:25 +02:00
parent d82f9d958d
commit c7876aa8b6
4 changed files with 28 additions and 1 deletions

View File

@ -120,6 +120,20 @@ cdef class ArcEager(TransitionSystem):
if state.sent[i].head == 0 and state.sent[i].dep == 0:
state.sent[i].dep = root_label
cdef bint* get_valid(self, const State* s) except NULL:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = _can_shift(s)
is_valid[REDUCE] = _can_reduce(s)
is_valid[LEFT] = _can_left(s)
is_valid[RIGHT] = _can_right(s)
is_valid[BREAK] = _can_break(s)
is_valid[CONSTITUENT] = _can_constituent(s)
is_valid[ADJUST] = _can_adjust(s)
cdef int i
for i in range(self.n_moves):
self._is_valid[i] = is_valid[self.c[i].move]
return self._is_valid
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
cdef bint[N_MOVES] is_valid
is_valid[SHIFT] = _can_shift(s)
@ -451,4 +465,3 @@ cdef inline bint _can_adjust(const State* s) nogil:
# return False
#elif b0 >= b1:
# return False
return True

View File

@ -140,6 +140,13 @@ cdef class BiluoPushDown(TransitionSystem):
t.score = score
return t
cdef bint* get_valid(self, const State* s) except NULL:
cdef int i
for i in range(self.n_moves):
m = &self.c[i]
self._is_valid[i] = _is_valid(m.move, m.label, s)
return self._is_valid
cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1:
if not _is_valid(self.move, self.label, s):

View File

@ -28,6 +28,7 @@ cdef class TransitionSystem:
cdef Pool mem
cdef StringStore strings
cdef const Transition* c
cdef bint* _is_valid
cdef readonly int n_moves
cdef int initialize_state(self, State* state) except -1
@ -39,6 +40,8 @@ cdef class TransitionSystem:
cdef Transition init_transition(self, int clas, int move, int label) except *
cdef bint* get_valid(self, const State* state) except NULL
cdef Transition best_valid(self, const weight_t* scores, const State* state) except *
cdef Transition best_gold(self, const weight_t* scores, const State* state,

View File

@ -15,6 +15,7 @@ cdef class TransitionSystem:
def __init__(self, StringStore string_table, dict labels_by_action):
self.mem = Pool()
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
cdef int i = 0
cdef int label_id
@ -43,6 +44,9 @@ cdef class TransitionSystem:
cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
raise NotImplementedError
cdef bint* get_valid(self, const State* state) except NULL:
raise NotImplementedError
cdef Transition best_gold(self, const weight_t* scores, const State* s,
GoldParse gold) except *: