diff --git a/examples/keras_parikh_entailment/__main__.py b/examples/keras_parikh_entailment/__main__.py
index b0d5340f5..6886226bb 100644
--- a/examples/keras_parikh_entailment/__main__.py
+++ b/examples/keras_parikh_entailment/__main__.py
@@ -22,22 +22,13 @@ def train(model_dir, train_loc, dev_loc, shape, settings):
     print("Compiling network")
     model = build_model(get_embeddings(nlp.vocab), shape, settings)
     print("Processing texts...")
-    train_X1 = get_word_ids(list(nlp.pipe(train_texts1, n_threads=10, batch_size=10000)),
-                            max_length=shape[0],
-                            tree_truncate=settings['tree_truncate'])
-    train_X2 = get_word_ids(list(nlp.pipe(train_texts2, n_threads=10, batch_size=10000)),
-                            max_length=shape[0],
-                            tree_truncate=settings['tree_truncate'])
-    dev_X1 = get_word_ids(list(nlp.pipe(dev_texts1, n_threads=10, batch_size=10000)),
-                          max_length=shape[0],
-                          tree_truncate=settings['tree_truncate'])
-    dev_X2 = get_word_ids(list(nlp.pipe(dev_texts2, n_threads=10, batch_size=10000)),
-                          max_length=shape[0],
-                          tree_truncate=settings['tree_truncate'])
-
-    print(train_X1.shape, train_X2.shape)
-    print(dev_X1.shape, dev_X2.shape)
-    print(train_labels.shape, dev_labels.shape)
+    Xs = []
+    for texts in (train_texts1, train_texts2, dev_texts1, dev_texts2):
+        Xs.append(get_word_ids(list(nlp.pipe(texts, n_threads=20, batch_size=20000)),
+                               max_length=shape[0],
+                               rnn_encode=settings['gru_encode'],
+                               tree_truncate=settings['tree_truncate']))
+    train_X1, train_X2, dev_X1, dev_X2 = Xs
     print(settings)
     model.fit(
         [train_X1, train_X2],
@@ -103,7 +94,7 @@ def read_snli(path):
     dropout=("Dropout level", "option", "d", float),
     learn_rate=("Learning rate", "option", "e", float),
     batch_size=("Batch size for neural network training", "option", "b", float),
-    nr_epoch=("Number of training epochs", "option", "i", float),
+    nr_epoch=("Number of training epochs", "option", "i", int),
     tree_truncate=("Truncate sentences by tree distance", "flag", "T", bool),
     gru_encode=("Encode sentences with bidirectional GRU", "flag", "E", bool),
 )
diff --git a/examples/keras_parikh_entailment/keras_decomposable_attention.py b/examples/keras_parikh_entailment/keras_decomposable_attention.py
index eb8a08f7d..84663cf17 100644
--- a/examples/keras_parikh_entailment/keras_decomposable_attention.py
+++ b/examples/keras_parikh_entailment/keras_decomposable_attention.py
@@ -107,8 +107,6 @@ class _Attention(object):
         def _outer(AB):
             att_ji = K.batch_dot(AB[1], K.permute_dimensions(AB[0], (0, 2, 1)))
             return K.permute_dimensions(att_ji,(0, 2, 1))
-
-
         return merge(
                 [self.model(sent1), self.model(sent2)],
                 mode=_outer,
@@ -153,6 +151,7 @@ class _Comparison(object):
     def __call__(self, sent, align, **kwargs):
         result = self.model(merge([sent, align], mode='concat')) # Shape: (i, n)
         result = _GlobalSumPooling1D()(result, mask=self.words)
+        result = BatchNormalization()(result)
         return result
 
 
diff --git a/examples/keras_parikh_entailment/spacy_hook.py b/examples/keras_parikh_entailment/spacy_hook.py
index 71d6c3add..082e39ba9 100644
--- a/examples/keras_parikh_entailment/spacy_hook.py
+++ b/examples/keras_parikh_entailment/spacy_hook.py
@@ -40,16 +40,19 @@ def get_embeddings(vocab):
     return vectors
 
 
-def get_word_ids(docs, tree_truncate=False, max_length=100):
+def get_word_ids(docs, rnn_encode=False, tree_truncate=False, max_length=100):
     Xs = numpy.zeros((len(docs), max_length), dtype='int32')
     for i, doc in enumerate(docs):
-        j = 0
-        queue = [sent.root for sent in doc.sents]
+        if tree_truncate:
+            queue = [sent.root for sent in doc.sents]
+        else:
+            queue = list(doc)
         words = []
         while len(words) <= max_length and queue:
             word = queue.pop(0)
-            if word.has_vector and not word.is_punct and not word.is_space:
+            if rnn_encode or (word.has_vector and not word.is_punct and not word.is_space):
                 words.append(word)
+            if tree_truncate:
                 queue.extend(list(word.lefts))
                 queue.extend(list(word.rights))
         words.sort()
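For reference, a minimal usage sketch of the reworked `get_word_ids` signature. The texts and settings below are illustrative only; it assumes spaCy 1.x (where `nlp.pipe` accepts `n_threads`) and that `spacy_hook` from this example directory is on the import path:

```python
import spacy

from spacy_hook import get_word_ids

nlp = spacy.load('en')
texts = [u'A man is playing a guitar.', u'The musician performs on stage.']
docs = list(nlp.pipe(texts, n_threads=20, batch_size=20000))

# With tree_truncate=True, tokens are queued breadth-first from each sentence
# root, so truncation drops the tokens furthest from the root. With
# rnn_encode=True, the vector/punctuation/space filter is bypassed so the
# bidirectional GRU sees the full token sequence.
word_ids = get_word_ids(docs, rnn_encode=False, tree_truncate=True,
                        max_length=50)
print(word_ids.shape)  # (2, 50): one row of int32 word ids per doc
```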