diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd new file mode 100644 index 000000000..c588d984e --- /dev/null +++ b/spacy/syntax/transition_system.pxd @@ -0,0 +1,42 @@ +from cymem.cymem cimport Pool +from thinc.typedefs cimport weight_t + +from ..structs cimport TokenC +from ._state cimport State + + + +cdef struct Transition: + int clas + int move + int label + + weight_t score + int cost + + int (*get_cost)(const Transition* self, const State* state, const TokenC* gold) except -1 + + int (*is_valid)(const Transition* self, const State* state) except -1 + + int (*do)(const Transition* self, State* state) except -1 + + +ctypedef int (*get_cost_func_t)(const Transition* self, const State* state, + const TokenC* gold) except -1 + +ctypedef int (*is_valid_func_t)(const Transition* self, const State* state) except -1 + +ctypedef int (*do_func_t)(const Transition* self, State* state) except -1 + + +cdef class TransitionSystem: + cdef readonly dict label_ids + cdef Pool mem + cdef const Transition* c + + cdef Transition init_transition(self, int clas, int move, int label) except * + + cdef const Transition best_valid(self, const weight_t*, const State*) except * + + cdef const Transition best_gold(self, const weight_t*, const State*, + const TokenC*) except * diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx new file mode 100644 index 000000000..93849a59e --- /dev/null +++ b/spacy/syntax/transition_system.pyx @@ -0,0 +1,54 @@ +from cymem.cymem cimport Pool +from ._state cimport State +from ..structs cimport TokenC +from thinc.typedefs cimport weight_t + + +cdef weight_t MIN_SCORE = -90000 + + +cdef class TransitionSystem: + def __init__(self, dict labels_by_action): + self.mem = Pool() + self.n_moves = sum(len(labels) for labels in labels_by_action.items()) + moves = self.mem.alloc(self.n_moves, sizeof(Transition)) + cdef int i = 0 + self.label_ids = {} + for action, label_strs in sorted(labels_by_action.items()): + label_str = unicode(label_str) + label_id = self.label_ids.setdefault(label_str, len(self.label_ids)) + moves[i] = self.init_transition(i, action, label_id) + i += 1 + self.c = moves + + cdef Transition init_transition(self, int clas, int move, int label) except *: + raise NotImplementedError + + cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: + cdef Transition best + cdef weight_t score = MIN_SCORE + cdef int i + for i in range(self.n_moves): + if scores[i] > score and self.c[i].is_valid(&self.c[i], s): + best = self.c[i] + score = scores[i] + # Label Shift moves with the best Right-Arc label, for non-monotonic + # actions + #if best.move == SHIFT: + # score = MIN_SCORE + # for i in range(self.n_moves): + # if self.c[i].move == RIGHT and scores[i] > score: + # best.label = self.c[i].label + # score = scores[i] + return best + + cdef Transition best_gold(self, const weight_t* scores, const State* s, + const TokenC* gold) except *: + cdef Transition best + cdef weight_t score = MIN_SCORE + cdef int i + for i in range(self.n_moves): + if scores[i] > score and self.c[i].get_cost(&self.c[i], s, gold) == 0: + best = self.c[i] + score = scores[i] + return best