mirror of https://github.com/explosion/spaCy.git
Restore support for deeper networks in parser
This commit is contained in:
parent
e27262f431
commit
964707d795
|
@ -336,9 +336,25 @@ cdef class Parser:
|
||||||
|
|
||||||
feat_weights = state2vec.get_feat_weights()
|
feat_weights = state2vec.get_feat_weights()
|
||||||
cdef int i
|
cdef int i
|
||||||
|
cdef np.ndarray token_ids = numpy.zeros((nr_state, nr_feat), dtype='i')
|
||||||
|
cdef np.ndarray is_valid = numpy.zeros((nr_state, nr_feat), dtype='i')
|
||||||
|
cdef np.ndarray scores
|
||||||
|
c_token_ids = <int*>token_ids.data
|
||||||
|
c_is_valid = <int*>is_valid.data
|
||||||
while not next_step.empty():
|
while not next_step.empty():
|
||||||
for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True):
|
for i in range(next_step.size()):
|
||||||
self._parse_step(next_step[i], feat_weights, nr_class, nr_feat)
|
st = next_step[i]
|
||||||
|
st.set_context_tokens(&c_token_ids[i*nr_feat], nr_feat)
|
||||||
|
self.moves.set_valid(&c_is_valid[i*nr_class], st)
|
||||||
|
vectors = state2vec.begin_update(token_ids[:next_step.size()])
|
||||||
|
scores = vec2scores(vectors)
|
||||||
|
c_scores = <float*>scores.data
|
||||||
|
for i in range(next_step.size()):
|
||||||
|
st = next_step[i]
|
||||||
|
guess = arg_max_if_valid(
|
||||||
|
&c_scores[i*nr_class], &c_is_valid[i*nr_class], nr_class)
|
||||||
|
action = self.moves.c[guess]
|
||||||
|
action.do(st, action.label)
|
||||||
this_step, next_step = next_step, this_step
|
this_step, next_step = next_step, this_step
|
||||||
next_step.clear()
|
next_step.clear()
|
||||||
for st in this_step:
|
for st in this_step:
|
||||||
|
@ -349,6 +365,9 @@ cdef class Parser:
|
||||||
cdef void _parse_step(self, StateC* state,
|
cdef void _parse_step(self, StateC* state,
|
||||||
const float* feat_weights,
|
const float* feat_weights,
|
||||||
int nr_class, int nr_feat) nogil:
|
int nr_class, int nr_feat) nogil:
|
||||||
|
'''This only works with no hidden layers -- fast but inaccurate'''
|
||||||
|
#for i in cython.parallel.prange(next_step.size(), num_threads=4, nogil=True):
|
||||||
|
# self._parse_step(next_step[i], feat_weights, nr_class, nr_feat)
|
||||||
token_ids = <int*>calloc(nr_feat, sizeof(int))
|
token_ids = <int*>calloc(nr_feat, sizeof(int))
|
||||||
scores = <float*>calloc(nr_class, sizeof(float))
|
scores = <float*>calloc(nr_class, sizeof(float))
|
||||||
is_valid = <int*>calloc(nr_class, sizeof(int))
|
is_valid = <int*>calloc(nr_class, sizeof(int))
|
||||||
|
|
Loading…
Reference in New Issue