From 7edb2e171181f0f49fb4b1f54326fa9e2b97373b Mon Sep 17 00:00:00 2001 From: svlandeg Date: Mon, 20 May 2019 11:58:48 +0200 Subject: [PATCH] fix convolution layer --- .../pipeline/wiki_entity_linking/train_el.py | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py index 312e50cad..2d7ede48d 100644 --- a/examples/pipeline/wiki_entity_linking/train_el.py +++ b/examples/pipeline/wiki_entity_linking/train_el.py @@ -12,9 +12,9 @@ from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic -from thinc.api import chain, concatenate, flatten_add_lengths, clone +from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten from thinc.v2v import Model, Maxout, Affine -from thinc.t2v import Pooling, mean_pool +from thinc.t2v import Pooling, mean_pool, sum_pool from thinc.t2t import ParametricAttention from thinc.misc import Residual from thinc.misc import LayerNorm as LN @@ -96,13 +96,13 @@ class EL_Model: try: # if to_print: # print() - # print(article_count, "Training on article", article_id) + print(article_count, "Training on article", article_id) article_count += 1 article_docs = list() entities = list() golds = list() for inst_cluster in inst_cluster_set: - if instance_pos_count < 2: # TODO remove + if instance_pos_count < 2: # TODO del article_docs.append(train_doc[article_id]) entities.append(train_pos.get(inst_cluster)) golds.append(float(1.0)) @@ -228,16 +228,23 @@ class EL_Model: @staticmethod def _encoder(in_width, hidden_width): + conv_depth = 1 + cnn_maxout_pieces = 3 + with Model.define_operators({">>": chain}): + convolution = Residual((ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3, pieces=cnn_maxout_pieces)))) + encoder = SpacyVectors \ - >> flatten_add_lengths \ - >> ParametricAttention(in_width)\ - >> Pooling(mean_pool) \ - >> (ExtractWindow(nW=1) >> LN(Maxout(in_width, in_width * 3))) \ - >> zero_init(Affine(hidden_width, in_width, drop_factor=0.0)) + >> with_flatten(LN(Maxout(in_width, in_width)) >> convolution ** conv_depth, pad=conv_depth) \ + >> flatten_add_lengths \ + >> ParametricAttention(in_width)\ + >> Pooling(mean_pool) \ + >> Residual(zero_init(Maxout(in_width, in_width))) \ + >> zero_init(Affine(hidden_width, in_width, drop_factor=0.0)) # TODO: ReLu instead of LN(Maxout) ? # TODO: more convolutions ? + # sum_pool or mean_pool ? return encoder @@ -261,16 +268,17 @@ class EL_Model: print(doc_encoding) doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=self.DROP) - entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP) - concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))] - print("doc_encodings", len(doc_encodings), doc_encodings) + + entity_encodings, bp_encoding = self.entity_encoder.begin_update(entities, drop=self.DROP) print("entity_encodings", len(entity_encodings), entity_encodings) - print("concat_encodings", len(concat_encodings), concat_encodings) + + concat_encodings = [list(entity_encodings[i]) + list(doc_encodings[i]) for i in range(len(entities))] + # print("concat_encodings", len(concat_encodings), concat_encodings) predictions, bp_model = self.model.begin_update(np.asarray(concat_encodings), drop=self.DROP) - print("predictions", predictions) predictions = self.model.ops.flatten(predictions) + print("predictions", predictions) golds = self.model.ops.asarray(golds) loss, d_scores = self.get_loss(predictions, golds) @@ -287,15 +295,15 @@ class EL_Model: d_scores = d_scores.reshape((-1, 1)) d_scores = d_scores.astype(np.float32) - print("d_scores", d_scores) + # print("d_scores", d_scores) model_gradient = bp_model(d_scores, sgd=self.sgd) - print("model_gradient", model_gradient) + # print("model_gradient", model_gradient) doc_gradient = [x[0:self.ARTICLE_WIDTH] for x in model_gradient] - print("doc_gradient", doc_gradient) + # print("doc_gradient", doc_gradient) entity_gradient = [x[self.ARTICLE_WIDTH:] for x in model_gradient] - print("entity_gradient", entity_gradient) + # print("entity_gradient", entity_gradient) bp_doc(doc_gradient) bp_encoding(entity_gradient)