Fix non-maxout parser

Matthew Honnibal 2017-10-23 18:16:23 +02:00
parent 1036798155
commit e7556ff048
1 changed file with 6 additions and 3 deletions

@@ -117,7 +117,7 @@ cdef class precompute_hiddens:
             cached = gpu_cached
         self.nF = cached.shape[1]
         self.nP = getattr(lower_model, 'nP', 1)
-        self.nO = cached.shape[2] // self.nP
+        self.nO = cached.shape[2]
         self.ops = lower_model.ops
         self.bias = lower_model.b
         self._is_synchronized = False
@@ -150,7 +150,7 @@ cdef class precompute_hiddens:
         sum_state_features(<float*>state_vector.data,
             feat_weights, &ids[0,0],
             token_ids.shape[0], self.nF, self.nO*self.nP)
-        state_vector += self.bias.ravel()
+        state_vector += self.bias
         state_vector, bp_nonlinearity = self._nonlinearity(state_vector)

         def backward(d_state_vector, sgd=None):
@@ -164,6 +164,7 @@ cdef class precompute_hiddens:

     def _nonlinearity(self, state_vector):
         if self.nP == 1:
+            state_vector = state_vector.reshape(state_vector.shape[:-1])
             mask = state_vector >= 0.
             state_vector *= mask
         else:
@@ -171,7 +172,9 @@ cdef class precompute_hiddens:

         def backprop_nonlinearity(d_best, sgd=None):
             if self.nP == 1:
-                return d_best * mask
+                d_best *= mask
+                d_best = d_best.reshape((d_best.shape + (1,)))
+                return d_best
             else:
                 return self.ops.backprop_maxout(d_best, mask, self.nP)
         return state_vector, backprop_nonlinearity
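For context, here is a minimal NumPy sketch (not the spaCy implementation itself) of what the fixed non-maxout path does: with nP == 1 the hidden state keeps a trailing pieces axis of size 1, so the forward pass drops that axis before applying the ReLU mask, and the backward pass restores it so the gradient's shape matches the precomputed cache. The nonlinearity helper and its maxout branch below are illustrative stand-ins, not functions from the library.

# Hypothetical standalone sketch of the nP == 1 ("non-maxout") path fixed above.
import numpy

def nonlinearity(state_vector, nP):
    if nP == 1:
        # (n, nO, 1) -> (n, nO): drop the single "pieces" axis before the ReLU
        state_vector = state_vector.reshape(state_vector.shape[:-1])
        mask = state_vector >= 0.
        state_vector = state_vector * mask

        def backprop(d_best):
            d_best = d_best * mask
            # restore the trailing pieces axis so the gradient's shape
            # matches the (n, nO, nP) cache it flows back into
            return d_best.reshape(d_best.shape + (1,))
        return state_vector, backprop
    else:
        # maxout path: keep the best piece per unit and remember which one won
        best = state_vector.argmax(axis=-1)
        mask = numpy.eye(nP, dtype=state_vector.dtype)[best]
        maxed = state_vector.max(axis=-1)

        def backprop(d_best):
            # route each unit's gradient to its winning piece
            return d_best[..., None] * mask
        return maxed, backprop

# Tiny usage check: shapes round-trip through forward and backward.
x = numpy.random.randn(4, 8, 1).astype('f')   # (batch, nO, nP) with nP == 1
y, bp = nonlinearity(x, nP=1)                 # y: (4, 8)
dx = bp(numpy.ones_like(y))                   # dx: (4, 8, 1), matching x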