Adjust call to scatter_add for the new version

This commit is contained in:
Matthew Honnibal 2017-10-27 01:16:55 +00:00
parent 35977bdbb9
commit bb25bdcd92
1 changed files with 5 additions and 6 deletions

View File

@ -1,5 +1,4 @@
# cython: infer_types=True # cython: infer_types=True
# cython: profile=True
# cython: cdivision=True # cython: cdivision=True
# cython: boundscheck=False # cython: boundscheck=False
# coding: utf-8 # coding: utf-8
@ -435,8 +434,7 @@ cdef class Parser:
cdef int nr_hidden = hidden_weights.shape[0] cdef int nr_hidden = hidden_weights.shape[0]
cdef int nr_task = states.size() cdef int nr_task = states.size()
with nogil: with nogil:
for i in cython.parallel.prange(nr_task, num_threads=2, for i in range(nr_task):
schedule='guided'):
self._parseC(states[i], self._parseC(states[i],
feat_weights, bias, hW, hb, feat_weights, bias, hW, hb,
nr_class, nr_hidden, nr_feat, nr_piece) nr_class, nr_hidden, nr_feat, nr_piece)
@ -697,9 +695,10 @@ cdef class Parser:
xp = get_array_module(d_tokvecs) xp = get_array_module(d_tokvecs)
for ids, d_vector, bp_vector in backprops: for ids, d_vector, bp_vector in backprops:
d_state_features = bp_vector(d_vector, sgd=sgd) d_state_features = bp_vector(d_vector, sgd=sgd)
mask = ids >= 0 ids = ids.flatten()
d_state_features *= mask.reshape(ids.shape + (1,)) d_state_features = d_state_features.reshape(
self.model[0].ops.scatter_add(d_tokvecs, ids * mask, (ids.size, d_state_features.shape[2]))
self.model[0].ops.scatter_add(d_tokvecs, ids,
d_state_features) d_state_features)
bp_tokvecs(d_tokvecs, sgd=sgd) bp_tokvecs(d_tokvecs, sgd=sgd)