# spaCy/examples/keras_parikh_entailment/keras_decomposable_attentio...
#
# 153 lines
# 4.7 KiB
# Python

# Semantic entailment/similarity with decomposable attention (using spaCy and Keras)
# Practical state-of-the-art textual entailment with spaCy and Keras
import numpy as np
from keras import layers, Model, models, optimizers
from keras import backend as K
def build_model(vectors, shape, settings):
    """Build and compile the decomposable-attention entailment model.

    Two padded sequences of word ids share one projected embedding and are
    then processed in three steps: attend, compare, aggregate.

    Args:
        vectors: 2-D array of word vectors, forwarded to ``create_embedding``.
        shape: Tuple ``(max_length, nr_hidden, nr_class)``.
        settings: Mapping read for the following keys:
            ``"lr"``: learning rate handed to the Adam optimizer.
            ``"entail_dir"``: ``"both"`` uses both attention directions,
            ``"left"`` uses only the ``alpha`` direction, and any other
            value uses only the ``beta`` direction.
            ``"dropout"`` (optional): dropout rate for the feed-forward
            blocks; falls back to 0.2 when the key is absent.

    Returns:
        A compiled Keras ``Model`` taking inputs named ``words1`` and
        ``words2`` and producing a softmax over ``nr_class`` classes.
    """
    max_length, nr_hidden, nr_class = shape
    # Previously the "dropout" setting was ignored and every feed-forward
    # block used its built-in default; honour it, keeping 0.2 as fallback.
    dropout_rate = settings.get("dropout", 0.2)

    input1 = layers.Input(shape=(max_length,), dtype="int32", name="words1")
    input2 = layers.Input(shape=(max_length,), dtype="int32", name="words2")

    # embeddings (projected); the same embedding model is applied to both inputs
    embed = create_embedding(vectors, max_length, nr_hidden)
    a = embed(input1)
    b = embed(input2)

    # step 1: attend
    # Pairwise scores between every position of `a` and every position of `b`.
    F = create_feedforward(nr_hidden, dropout_rate=dropout_rate)
    att_weights = layers.dot([F(a), F(b)], axes=-1)

    # G is shared by both comparison directions.
    G = create_feedforward(nr_hidden, dropout_rate=dropout_rate)

    if settings["entail_dir"] == "both":
        # Normalise the scores over axis 1 and over axis 2 respectively.
        norm_weights_a = layers.Lambda(normalizer(1))(att_weights)
        norm_weights_b = layers.Lambda(normalizer(2))(att_weights)
        alpha = layers.dot([norm_weights_a, a], axes=1)
        beta = layers.dot([norm_weights_b, b], axes=1)

        # step 2: compare
        comp1 = layers.concatenate([a, beta])
        comp2 = layers.concatenate([b, alpha])
        v1 = layers.TimeDistributed(G)(comp1)
        v2 = layers.TimeDistributed(G)(comp2)

        # step 3: aggregate
        v1_sum = layers.Lambda(sum_word)(v1)
        v2_sum = layers.Lambda(sum_word)(v2)
        concat = layers.concatenate([v1_sum, v2_sum])
    elif settings["entail_dir"] == "left":
        # Single direction: only `alpha` (attention-weighted `a`) is compared
        # against `b`.
        norm_weights_a = layers.Lambda(normalizer(1))(att_weights)
        alpha = layers.dot([norm_weights_a, a], axes=1)
        comp2 = layers.concatenate([b, alpha])
        v2 = layers.TimeDistributed(G)(comp2)
        v2_sum = layers.Lambda(sum_word)(v2)
        concat = v2_sum
    else:
        # Any other value: only `beta` (attention-weighted `b`) is compared
        # against `a`.
        norm_weights_b = layers.Lambda(normalizer(2))(att_weights)
        beta = layers.dot([norm_weights_b, b], axes=1)
        comp1 = layers.concatenate([a, beta])
        v1 = layers.TimeDistributed(G)(comp1)
        v1_sum = layers.Lambda(sum_word)(v1)
        concat = v1_sum

    # Final feed-forward block followed by the softmax classifier.
    H = create_feedforward(nr_hidden, dropout_rate=dropout_rate)
    out = H(concat)
    out = layers.Dense(nr_class, activation="softmax")(out)

    model = Model([input1, input2], out)
    model.compile(
        optimizer=optimizers.Adam(lr=settings["lr"]),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
def create_embedding(vectors, max_length, projected_dim):
    """Return a Sequential model: frozen embedding lookup plus projection.

    Args:
        vectors: 2-D array used as the initial (non-trainable) embedding
            weights; its two dimensions give the table size and vector width.
        max_length: Input sequence length passed to the embedding layer.
        projected_dim: Output width of the bias-free linear projection that
            is applied to every position.
    """
    table_rows = vectors.shape[0]
    vector_width = vectors.shape[1]

    # The embedding weights are fixed; only the projection is trained.
    frozen_lookup = layers.Embedding(
        table_rows,
        vector_width,
        input_length=max_length,
        weights=[vectors],
        trainable=False,
    )
    linear_projection = layers.Dense(projected_dim, activation=None, use_bias=False)
    per_position_projection = layers.TimeDistributed(linear_projection)

    return models.Sequential([frozen_lookup, per_position_projection])
def create_feedforward(num_units=200, activation="relu", dropout_rate=0.2):
    """Return a Sequential model of two Dense layers, each followed by Dropout.

    Args:
        num_units: Width of both Dense layers.
        activation: Activation used by both Dense layers.
        dropout_rate: Rate used by both Dropout layers.
    """
    stack = []
    # Two identical (Dense, Dropout) pairs, created in that order.
    for _ in range(2):
        stack.append(layers.Dense(num_units, activation=activation))
        stack.append(layers.Dropout(dropout_rate))
    return models.Sequential(stack)
def normalizer(axis):
    """Return a function that applies a softmax along ``axis`` of a tensor.

    Args:
        axis: Axis over which the returned function normalises, so that the
            values along that axis sum to one.

    Returns:
        A one-argument function suitable for wrapping in ``layers.Lambda``.
    """

    def _normalize(att_weights):
        # Softmax is unchanged by subtracting a constant along the normalised
        # axis; shifting by the maximum keeps exp() from overflowing to inf
        # (and the division from producing NaN) when scores are large.
        shifted = att_weights - K.max(att_weights, axis=axis, keepdims=True)
        exp_weights = K.exp(shifted)
        sum_weights = K.sum(exp_weights, axis=axis, keepdims=True)
        return exp_weights / sum_weights

    return _normalize
def sum_word(x):
    """Sum the tensor ``x`` over axis 1 and return the reduced tensor."""
    reduce_axis = 1
    summed = K.sum(x, axis=reduce_axis)
    return summed
def test_build_model():
    """Smoke test: the model builds and has one output per class."""
    # np.ndarray(...) returns uninitialised memory (possibly NaN/inf), so use
    # seeded random vectors to keep the test deterministic.
    rng = np.random.RandomState(0)
    vectors = rng.uniform(-1.0, 1.0, size=(100, 8)).astype("float32")
    shape = (10, 16, 3)
    settings = {"lr": 0.001, "dropout": 0.2, "gru_encode": True, "entail_dir": "both"}
    model = build_model(vectors, shape, settings)
    # The final layer is Dense(nr_class), so the output is (batch, nr_class).
    assert model.output_shape == (None, shape[2])
def test_fit_model():
    """Smoke test: the model trains for a few epochs on synthetic data."""
    # A seeded generator replaces np.ndarray(...), which exposes uninitialised
    # memory and made both the word ids and the vectors nondeterministic.
    rng = np.random.RandomState(0)

    def _generate_X(nr_example, length, nr_vector):
        # Two id matrices with every entry a valid row index in [0, nr_vector).
        X1 = rng.randint(0, nr_vector, size=(nr_example, length)).astype("int32")
        X2 = rng.randint(0, nr_vector, size=(nr_example, length)).astype("int32")
        return [X1, X2]

    def _generate_Y(nr_example, nr_class):
        # One-hot labels cycling through the classes: example i gets class
        # i % nr_class.
        ys = np.zeros((nr_example, nr_class), dtype="int32")
        rows = np.arange(nr_example)
        ys[rows, rows % nr_class] = 1
        return ys

    vectors = rng.uniform(-1.0, 1.0, size=(100, 8)).astype("float32")
    shape = (10, 16, 3)
    settings = {"lr": 0.001, "dropout": 0.2, "gru_encode": True, "entail_dir": "both"}
    model = build_model(vectors, shape, settings)

    train_X = _generate_X(20, shape[0], vectors.shape[0])
    train_Y = _generate_Y(20, shape[2])
    dev_X = _generate_X(15, shape[0], vectors.shape[0])
    dev_Y = _generate_Y(15, shape[2])

    model.fit(train_X, train_Y, validation_data=(dev_X, dev_Y), epochs=5, batch_size=4)
# __all__ must list names as strings; a function object here makes
# `from <module> import *` raise TypeError.
__all__ = ["build_model"]