Try to make sum_state_features faster

This commit is contained in:
Matthew Honnibal 2018-03-27 10:08:38 +00:00
parent 987e1533a4
commit 25280b7013
1 changed files with 5 additions and 4 deletions

View File

@ -165,16 +165,17 @@ cdef void sum_state_features(float* output,
cdef const float* feature cdef const float* feature
padding = cached padding = cached
cached += F * O cached += F * O
cdef int id_stride = F*O
cdef float one = 1.
for b in range(B): for b in range(B):
for f in range(F): for f in range(F):
if token_ids[f] < 0: if token_ids[f] < 0:
feature = &padding[f*O] feature = &padding[f*O]
else: else:
idx = token_ids[f] * F * O + f*O idx = token_ids[f] * id_stride + f*O
feature = &cached[idx] feature = &cached[idx]
VecVec.add_i(output, openblas.simple_axpy(&output[b*O], O,
feature, 1., O) feature, one)
output += O
token_ids += F token_ids += F