diff --git a/examples/keras_parikh_entailment/keras_decomposable_attention.py b/examples/keras_parikh_entailment/keras_decomposable_attention.py
index 06a352f91..eb8a08f7d 100644
--- a/examples/keras_parikh_entailment/keras_decomposable_attention.py
+++ b/examples/keras_parikh_entailment/keras_decomposable_attention.py
@@ -6,7 +6,6 @@ from keras.layers import InputSpec, Layer, Input, Dense, merge
 from keras.layers import Activation, Dropout, Embedding, TimeDistributed
 from keras.layers import Bidirectional, GRU
 import keras.backend as K
-import theano.tensor as T
 from keras.models import Sequential, Model, model_from_json
 from keras.regularizers import l2
 from keras.optimizers import Adam
@@ -106,8 +105,8 @@ class _Attention(object):
 
     def __call__(self, sent1, sent2):
         def _outer(AB):
-            att_ji = T.batched_dot(AB[1], AB[0].dimshuffle((0, 2, 1)))
-            return att_ji.dimshuffle((0, 2, 1))
+            att_ji = K.batch_dot(AB[1], K.permute_dimensions(AB[0], (0, 2, 1)))
+            return K.permute_dimensions(att_ji,(0, 2, 1))
 
         return merge(
@@ -126,12 +125,12 @@ class _SoftAlignment(object):
             att = attmat[0]
             mat = attmat[1]
             if transpose:
-                att = att.dimshuffle((0, 2, 1))
+                att = K.permute_dimensions(att,(0, 2, 1))
             # 3d softmax
             e = K.exp(att - K.max(att, axis=-1, keepdims=True))
             s = K.sum(e, axis=-1, keepdims=True)
             sm_att = e / s
-            return T.batched_dot(sm_att, mat)
+            return K.batch_dot(sm_att, mat)
         return merge([attention, sentence], mode=_normalize_attention,
                      output_shape=(self.max_length, self.nr_hidden)) # Shape: (i, n)
 
@@ -191,7 +190,7 @@ class _GlobalSumPooling1D(Layer):
 
     def call(self, x, mask=None):
         if mask is not None:
-            return K.sum(x * T.clip(mask, 0, 1), axis=1)
+            return K.sum(x * K.clip(mask, 0, 1), axis=1)
         else:
             return K.sum(x, axis=1)
 
@@ -204,6 +203,7 @@ def test_build_model():
 
 
 def test_fit_model():
+
     def _generate_X(nr_example, length, nr_vector):
         X1 = numpy.ndarray((nr_example, length), dtype='int32')
         X1 *= X1 < nr_vector
@@ -212,6 +212,7 @@ def test_fit_model():
         X2 *= X2 < nr_vector
         X2 *= 0 <= X2
         return [X1, X2]
+
     def _generate_Y(nr_example, nr_class):
         ys = numpy.zeros((nr_example, nr_class), dtype='int32')
         for i in range(nr_example):
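
The hunks above swap Theano-specific tensor calls for their backend-agnostic Keras equivalents: T.batched_dot -> K.batch_dot, x.dimshuffle((0, 2, 1)) -> K.permute_dimensions(x, (0, 2, 1)), and T.clip -> K.clip. The following is a minimal sketch, not part of the patch, that exercises those replacement calls in isolation; it assumes the Keras 1/2-era keras.backend API the file already imports, and the tensor shapes are illustrative only.

# Sketch only: exercises the backend-agnostic calls introduced by the patch.
# Assumes keras.backend exposes variable/eval/batch_dot/permute_dimensions/clip
# (true for the Keras versions this example targets).
import numpy
import keras.backend as K

a = K.variable(numpy.random.rand(2, 4, 3))   # (batch, len1, hidden)
b = K.variable(numpy.random.rand(2, 5, 3))   # (batch, len2, hidden)

# As in _Attention._outer: attend sent2 over sent1, then transpose back.
att_ji = K.batch_dot(b, K.permute_dimensions(a, (0, 2, 1)))  # shape (2, 5, 4)
att_ij = K.permute_dimensions(att_ji, (0, 2, 1))             # shape (2, 4, 5)

# As in _GlobalSumPooling1D.call: clip the mask into [0, 1] before summing.
mask = K.clip(K.variable(numpy.array([[2.0, -1.0, 1.0, 0.0]])), 0, 1)

print(K.eval(att_ij).shape)   # (2, 4, 5)
print(K.eval(mask))           # [[1. 0. 1. 0.]]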