mirror of https://github.com/explosion/spaCy.git
Adjust call to scatter_add for the new version
This commit is contained in:
parent
35977bdbb9
commit
bb25bdcd92
|
@ -1,5 +1,4 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
# cython: cdivision=True
|
||||
# cython: boundscheck=False
|
||||
# coding: utf-8
|
||||
|
@ -435,8 +434,7 @@ cdef class Parser:
|
|||
cdef int nr_hidden = hidden_weights.shape[0]
|
||||
cdef int nr_task = states.size()
|
||||
with nogil:
|
||||
for i in cython.parallel.prange(nr_task, num_threads=2,
|
||||
schedule='guided'):
|
||||
for i in range(nr_task):
|
||||
self._parseC(states[i],
|
||||
feat_weights, bias, hW, hb,
|
||||
nr_class, nr_hidden, nr_feat, nr_piece)
|
||||
|
@ -697,9 +695,10 @@ cdef class Parser:
|
|||
xp = get_array_module(d_tokvecs)
|
||||
for ids, d_vector, bp_vector in backprops:
|
||||
d_state_features = bp_vector(d_vector, sgd=sgd)
|
||||
mask = ids >= 0
|
||||
d_state_features *= mask.reshape(ids.shape + (1,))
|
||||
self.model[0].ops.scatter_add(d_tokvecs, ids * mask,
|
||||
ids = ids.flatten()
|
||||
d_state_features = d_state_features.reshape(
|
||||
(ids.size, d_state_features.shape[2]))
|
||||
self.model[0].ops.scatter_add(d_tokvecs, ids,
|
||||
d_state_features)
|
||||
bp_tokvecs(d_tokvecs, sgd=sgd)
|
||||
|
||||
|
|
Loading…
Reference in New Issue