From c7876aa8b6f188413ff3b7e2b1699575e8572ea9 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 1 Jun 2015 23:05:25 +0200 Subject: [PATCH] * Add get_valid method --- spacy/syntax/arc_eager.pyx | 15 ++++++++++++++- spacy/syntax/ner.pyx | 7 +++++++ spacy/syntax/transition_system.pxd | 3 +++ spacy/syntax/transition_system.pyx | 4 ++++ 4 files changed, 28 insertions(+), 1 deletion(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 2c0e3fd99..946cd540b 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -120,6 +120,20 @@ cdef class ArcEager(TransitionSystem): if state.sent[i].head == 0 and state.sent[i].dep == 0: state.sent[i].dep = root_label + cdef bint* get_valid(self, const State* s) except NULL: + cdef bint[N_MOVES] is_valid + is_valid[SHIFT] = _can_shift(s) + is_valid[REDUCE] = _can_reduce(s) + is_valid[LEFT] = _can_left(s) + is_valid[RIGHT] = _can_right(s) + is_valid[BREAK] = _can_break(s) + is_valid[CONSTITUENT] = _can_constituent(s) + is_valid[ADJUST] = _can_adjust(s) + cdef int i + for i in range(self.n_moves): + self._is_valid[i] = is_valid[self.c[i].move] + return self._is_valid + cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: cdef bint[N_MOVES] is_valid is_valid[SHIFT] = _can_shift(s) @@ -451,4 +465,3 @@ cdef inline bint _can_adjust(const State* s) nogil: # return False #elif b0 >= b1: # return False - return True diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 76b1a530c..426a715d7 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -140,6 +140,13 @@ cdef class BiluoPushDown(TransitionSystem): t.score = score return t + cdef bint* get_valid(self, const State* s) except NULL: + cdef int i + for i in range(self.n_moves): + m = &self.c[i] + self._is_valid[i] = _is_valid(m.move, m.label, s) + return self._is_valid + cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1: if not _is_valid(self.move, self.label, s): diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 3ac1b62f6..57f1943b2 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -28,6 +28,7 @@ cdef class TransitionSystem: cdef Pool mem cdef StringStore strings cdef const Transition* c + cdef bint* _is_valid cdef readonly int n_moves cdef int initialize_state(self, State* state) except -1 @@ -39,6 +40,8 @@ cdef class TransitionSystem: cdef Transition init_transition(self, int clas, int move, int label) except * + cdef bint* get_valid(self, const State* state) except NULL + cdef Transition best_valid(self, const weight_t* scores, const State* state) except * cdef Transition best_gold(self, const weight_t* scores, const State* state, diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 0fea8d8c4..67c33155c 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -15,6 +15,7 @@ cdef class TransitionSystem: def __init__(self, StringStore string_table, dict labels_by_action): self.mem = Pool() self.n_moves = sum(len(labels) for labels in labels_by_action.values()) + self._is_valid = self.mem.alloc(self.n_moves, sizeof(bint)) moves = self.mem.alloc(self.n_moves, sizeof(Transition)) cdef int i = 0 cdef int label_id @@ -43,6 +44,9 @@ cdef class TransitionSystem: cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: raise NotImplementedError + + cdef bint* get_valid(self, const State* state) except NULL: + raise NotImplementedError cdef Transition best_gold(self, const weight_t* scores, const State* s, GoldParse gold) except *: