mirror of https://github.com/explosion/spaCy.git
* Allow users to add_label, in order to extend the entity recogniser to new classes. Does not by itself add a class to the model
This commit is contained in:
parent
c8e0011ebc
commit
151aa0b0e2
|
@ -31,14 +31,12 @@ ctypedef int (*do_func_t)(StateClass state, int label) nogil
|
||||||
cdef class TransitionSystem:
|
cdef class TransitionSystem:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef StringStore strings
|
cdef StringStore strings
|
||||||
cdef const Transition* c
|
cdef Transition* c
|
||||||
cdef bint* _is_valid
|
|
||||||
cdef readonly int n_moves
|
cdef readonly int n_moves
|
||||||
|
cdef int _size
|
||||||
cdef public int root_label
|
cdef public int root_label
|
||||||
cdef public freqs
|
cdef public freqs
|
||||||
|
|
||||||
cdef object _labels_by_action
|
|
||||||
|
|
||||||
cdef int initialize_state(self, StateClass state) except -1
|
cdef int initialize_state(self, StateClass state) except -1
|
||||||
cdef int finalize_state(self, StateClass state) nogil
|
cdef int finalize_state(self, StateClass state) nogil
|
||||||
|
|
||||||
|
|
|
@ -16,20 +16,17 @@ class OracleError(Exception):
|
||||||
|
|
||||||
cdef class TransitionSystem:
|
cdef class TransitionSystem:
|
||||||
def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
|
def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
|
||||||
self._labels_by_action = labels_by_action
|
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
|
|
||||||
self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
|
|
||||||
moves = <Transition*>self.mem.alloc(self.n_moves, sizeof(Transition))
|
|
||||||
cdef int i = 0
|
|
||||||
cdef int label_id
|
|
||||||
self.strings = string_table
|
self.strings = string_table
|
||||||
|
self.n_moves = 0
|
||||||
|
self._size = 100
|
||||||
|
|
||||||
|
self.c = <Transition*>self.mem.alloc(self._size, sizeof(Transition))
|
||||||
|
|
||||||
for action, label_strs in sorted(labels_by_action.items()):
|
for action, label_strs in sorted(labels_by_action.items()):
|
||||||
for label_str in sorted(label_strs):
|
for label_str in sorted(label_strs):
|
||||||
label_id = self.strings[unicode(label_str)] if label_str else 0
|
self.add_action(int(action), label_str)
|
||||||
moves[i] = self.init_transition(i, int(action), label_id)
|
|
||||||
i += 1
|
|
||||||
self.c = moves
|
|
||||||
self.root_label = self.strings['ROOT']
|
self.root_label = self.strings['ROOT']
|
||||||
self.freqs = {} if _freqs is None else _freqs
|
self.freqs = {} if _freqs is None else _freqs
|
||||||
for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
|
for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
|
||||||
|
@ -41,8 +38,13 @@ cdef class TransitionSystem:
|
||||||
self.freqs[HEAD][-i] = 1
|
self.freqs[HEAD][-i] = 1
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
|
labels_by_action = {}
|
||||||
|
cdef Transition t
|
||||||
|
for trans in self.c[:self.n_moves]:
|
||||||
|
label_str = self.strings[trans.label]
|
||||||
|
labels_by_action.setdefault(trans.move, []).append(label_str)
|
||||||
return (self.__class__,
|
return (self.__class__,
|
||||||
(self.strings, self._labels_by_action, self.freqs),
|
(self.strings, labels_by_action, self.freqs),
|
||||||
None, None)
|
None, None)
|
||||||
|
|
||||||
cdef int initialize_state(self, StateClass state) except -1:
|
cdef int initialize_state(self, StateClass state) except -1:
|
||||||
|
@ -78,3 +80,14 @@ cdef class TransitionSystem:
|
||||||
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
costs[i] = self.c[i].get_cost(stcls, &gold.c, self.c[i].label)
|
||||||
else:
|
else:
|
||||||
costs[i] = 9000
|
costs[i] = 9000
|
||||||
|
|
||||||
|
def add_action(self, int action, label):
|
||||||
|
if self.n_moves >= self._size:
|
||||||
|
self._size *= 2
|
||||||
|
self.c = <Transition*>self.mem.realloc(self.c, self._size * sizeof(self.c[0]))
|
||||||
|
|
||||||
|
if not isinstance(label, int):
|
||||||
|
label = self.strings[label]
|
||||||
|
|
||||||
|
self.c[self.n_moves] = self.init_transition(self.n_moves, action, label)
|
||||||
|
self.n_moves += 1
|
||||||
|
|
Loading…
Reference in New Issue