* Allow users to add_label, in order to extend the entity recogniser to new classes. Does not by itself add a class to the model

This commit is contained in:
Matthew Honnibal 2016-01-19 19:09:33 +01:00
parent c8e0011ebc
commit 151aa0b0e2
2 changed files with 26 additions and 15 deletions

View File

@ -31,14 +31,12 @@ ctypedef int (*do_func_t)(StateClass state, int label) nogil
cdef class TransitionSystem: cdef class TransitionSystem:
cdef Pool mem cdef Pool mem
cdef StringStore strings cdef StringStore strings
cdef const Transition* c cdef Transition* c
cdef bint* _is_valid
cdef readonly int n_moves cdef readonly int n_moves
cdef int _size
cdef public int root_label cdef public int root_label
cdef public freqs cdef public freqs
cdef object _labels_by_action
cdef int initialize_state(self, StateClass state) except -1 cdef int initialize_state(self, StateClass state) except -1
cdef int finalize_state(self, StateClass state) nogil cdef int finalize_state(self, StateClass state) nogil

View File

@ -16,20 +16,17 @@ class OracleError(Exception):
cdef class TransitionSystem: cdef class TransitionSystem:
def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None): def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
self._labels_by_action = labels_by_action
self.mem = Pool() self.mem = Pool()
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
cdef int i = 0
cdef int label_id
self.strings = string_table self.strings = string_table
self.n_moves = 0
self._size = 100
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
for action, label_strs in sorted(labels_by_action.items()): for action, label_strs in sorted(labels_by_action.items()):
for label_str in sorted(label_strs): for label_str in sorted(label_strs):
label_id = self.strings[unicode(label_str)] if label_str else 0 self.add_action(int(action), label_str)
moves[i] = self.init_transition(i, int(action), label_id)
i += 1
self.c = moves
self.root_label = self.strings['ROOT'] self.root_label = self.strings['ROOT']
self.freqs = {} if _freqs is None else _freqs self.freqs = {} if _freqs is None else _freqs
for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB): for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
@ -41,8 +38,13 @@ cdef class TransitionSystem:
self.freqs[HEAD][-i] = 1 self.freqs[HEAD][-i] = 1
def __reduce__(self): def __reduce__(self):
labels_by_action = {}
cdef Transition t
for trans in self.c[:self.n_moves]:
label_str = self.strings[trans.label]
labels_by_action.setdefault(trans.move, []).append(label_str)
return (self.__class__, return (self.__class__,
(self.strings, self._labels_by_action, self.freqs), (self.strings, labels_by_action, self.freqs),
None, None) None, None)
cdef int initialize_state(self, StateClass state) except -1: cdef int initialize_state(self, StateClass state) except -1:
@ -78,3 +80,14 @@ cdef class TransitionSystem:
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label) costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
else: else:
costs[i] = 9000 costs[i] = 9000
def add_action(self, int action, label):
if self.n_moves >= self._size:
self._size *= 2
self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
if not isinstance(label, int):
label = self.strings[label]
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label)
self.n_moves += 1