diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd index 38bc91605..b25e7513a 100644 --- a/spacy/syntax/transition_system.pxd +++ b/spacy/syntax/transition_system.pxd @@ -31,14 +31,12 @@ ctypedef int (*do_func_t)(StateClass state, int label) nogil cdef class TransitionSystem: cdef Pool mem cdef StringStore strings - cdef const Transition* c - cdef bint* _is_valid + cdef Transition* c cdef readonly int n_moves + cdef int _size cdef public int root_label cdef public freqs - cdef object _labels_by_action - cdef int initialize_state(self, StateClass state) except -1 cdef int finalize_state(self, StateClass state) nogil diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx index 5de3513e0..ef12e0074 100644 --- a/spacy/syntax/transition_system.pyx +++ b/spacy/syntax/transition_system.pyx @@ -16,20 +16,17 @@ class OracleError(Exception): cdef class TransitionSystem: def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None): - self._labels_by_action = labels_by_action self.mem = Pool() - self.n_moves = sum(len(labels) for labels in labels_by_action.values()) - self._is_valid = self.mem.alloc(self.n_moves, sizeof(bint)) - moves = self.mem.alloc(self.n_moves, sizeof(Transition)) - cdef int i = 0 - cdef int label_id self.strings = string_table + self.n_moves = 0 + self._size = 100 + + self.c = self.mem.alloc(self._size, sizeof(Transition)) + for action, label_strs in sorted(labels_by_action.items()): for label_str in sorted(label_strs): - label_id = self.strings[unicode(label_str)] if label_str else 0 - moves[i] = self.init_transition(i, int(action), label_id) - i += 1 - self.c = moves + self.add_action(int(action), label_str) + self.root_label = self.strings['ROOT'] self.freqs = {} if _freqs is None else _freqs for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB): @@ -41,8 +38,13 @@ cdef class TransitionSystem: self.freqs[HEAD][-i] = 1 def __reduce__(self): + labels_by_action = {} + cdef Transition t + for trans in self.c[:self.n_moves]: + label_str = self.strings[trans.label] + labels_by_action.setdefault(trans.move, []).append(label_str) return (self.__class__, - (self.strings, self._labels_by_action, self.freqs), + (self.strings, labels_by_action, self.freqs), None, None) cdef int initialize_state(self, StateClass state) except -1: @@ -78,3 +80,14 @@ cdef class TransitionSystem: costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label) else: costs[i] = 9000 + + def add_action(self, int action, label): + if self.n_moves >= self._size: + self._size *= 2 + self.c = self.mem.realloc(self.c, self._size * sizeof(self.c[0])) + + if not isinstance(label, int): + label = self.strings[label] + + self.c[self.n_moves] = self.init_transition(self.n_moves, action, label) + self.n_moves += 1