diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index fb7099022..5d6f51538 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -633,10 +633,9 @@ cdef class Parser: xp = get_array_module(d_tokvecs) for ids, d_vector, bp_vector in backprops: d_state_features = bp_vector(d_vector, sgd=sgd) - mask = ids >= 0 - indices = xp.nonzero(mask) - self.model[0].ops.scatter_add(d_tokvecs, ids[indices], - d_state_features[indices]) + mask = (ids >= 0).reshape((ids.shape[0], ids.shape[1], 1)) + self.model[0].ops.scatter_add(d_tokvecs, ids, + d_state_features * mask) @property def move_names(self):