diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py index c91058d5f..cfd17bd78 100644 --- a/examples/pipeline/wiki_entity_linking/train_el.py +++ b/examples/pipeline/wiki_entity_linking/train_el.py @@ -4,17 +4,17 @@ from __future__ import unicode_literals import os import datetime from os import listdir +import numpy as np from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator -from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine +from spacy._ml import SpacyVectors, create_default_optimizer, zero_init -from thinc.api import chain +from thinc.api import chain, flatten_add_lengths, with_getitem, clone, with_flatten from thinc.v2v import Model, Maxout, Softmax, Affine, ReLu -from thinc.api import flatten_add_lengths from thinc.t2v import Pooling, sum_pool, mean_pool from thinc.t2t import ExtractWindow, ParametricAttention -from thinc.misc import Residual +from thinc.misc import Residual, LayerNorm as LN """ TODO: this code needs to be implemented in pipes.pyx""" @@ -29,8 +29,8 @@ class EL_Model(): self.nlp = nlp self.kb = kb - self.entity_encoder = self._simple_encoder(width=300) - self.article_encoder = self._simple_encoder(width=300) + self.entity_encoder = self._simple_encoder(in_width=300, out_width=96) + self.article_encoder = self._simple_encoder(in_width=300, out_width=96) def train_model(self, training_dir, entity_descr_output, limit=None, to_print=True): instances, pos_entities, neg_entities, doc_by_article = self._get_training_data(training_dir, @@ -61,13 +61,36 @@ class EL_Model(): # elif not neg_exs: # print("Weird. Couldn't find neg examples for", inst_cluster) - def _simple_encoder(self, width): - with Model.define_operators({">>": chain}): + def _simple_encoder(self, in_width, out_width): + conv_depth = 1 + cnn_maxout_pieces = 3 + with Model.define_operators({">>": chain, "**": clone}): + # encoder = SpacyVectors \ + # >> flatten_add_lengths \ + # >> ParametricAttention(in_width)\ + # >> Pooling(mean_pool) \ + # >> Residual(zero_init(Maxout(in_width, in_width))) \ + # >> zero_init(Affine(out_width, in_width, drop_factor=0.0)) encoder = SpacyVectors \ - >> flatten_add_lengths \ - >> ParametricAttention(width)\ - >> Pooling(sum_pool) \ - >> Residual(zero_init(Maxout(width, width))) + >> flatten_add_lengths \ + >> with_getitem(0, Affine(in_width, in_width)) \ + >> ParametricAttention(in_width) \ + >> Pooling(sum_pool) \ + >> Residual(ReLu(in_width, in_width)) ** conv_depth \ + >> zero_init(Affine(out_width, in_width, drop_factor=0.0)) + + # >> zero_init(Affine(nr_class, width, drop_factor=0.0)) + # >> logistic + + # convolution = Residual( + # ExtractWindow(nW=1) + # >> LN(Maxout(width, width * 3, pieces=cnn_maxout_pieces)) + # ) + + # embed = SpacyVectors >> LN(Maxout(width, width, pieces=3)) + + # encoder = SpacyVectors >> flatten_add_lengths >> convolution ** conv_depth + # encoder = with_flatten(embed >> convolution ** conv_depth, pad=conv_depth) return encoder @@ -80,25 +103,56 @@ class EL_Model(): doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop) true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop) - # true_similarity = cosine(true_entity_encoding, doc_encoding) - # print("true_similarity", true_similarity) + # print("encoding dim", len(true_entity_encoding[0])) - # for false_entity in false_entities: - # false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop) - # false_similarity = cosine(false_entity_encoding, doc_encoding) - # print("false_similarity", false_similarity) + consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding) + consensus_encoding_t = consensus_encoding.transpose() - # print("entity/article output dim", len(entity_encoding[0]), len(doc_encoding[0])) + doc_mse, doc_diffs = self._calculate_similarity(doc_encoding, consensus_encoding) - mse, diffs = self._calculate_similarity(true_entity_encoding, doc_encoding) + entity_mses = list() + + true_mse, true_diffs = self._calculate_similarity(true_entity_encoding, consensus_encoding) + # print("true_mse", true_mse) + # print("true_diffs", true_diffs) + entity_mses.append(true_mse) + # true_exp = np.exp(true_entity_encoding.dot(consensus_encoding_t)) + # print("true_exp", true_exp) + + # false_exp_sum = 0 + + for false_entity in false_entities: + false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop) + false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding) + # print("false_mse", false_mse) + # false_exp = np.exp(false_entity_encoding.dot(consensus_encoding_t)) + # print("false_exp", false_exp) + # print("false_diffs", false_diffs) + entity_mses.append(false_mse) + # if false_mse > true_mse: + # true_diffs = true_diffs - false_diffs ??? + # false_exp_sum += false_exp + + # prob = true_exp / false_exp_sum + # print("prob", prob) + + entity_mses = sorted(entity_mses) + # mse_sum = sum(entity_mses) + # entity_probs = [1 - x/mse_sum for x in entity_mses] + # print("entity_mses", entity_mses) + # print("entity_probs", entity_probs) + true_index = entity_mses.index(true_mse) + # print("true index", true_index) + # print("true prob", entity_probs[true_index]) + + print(true_mse) # print() # TODO: proper backpropagation taking ranking of elements into account ? # TODO backpropagation also for negative examples - true_entity_bp(diffs, sgd=self.sgd_entity) - article_bp(diffs, sgd=self.sgd_article) - print(mse) + true_entity_bp(true_diffs, sgd=self.sgd_entity) + article_bp(doc_diffs, sgd=self.sgd_article) # TODO delete ? @@ -124,11 +178,19 @@ class EL_Model(): return mse + # TODO: expand to more than 2 vectors + def _calculate_consensus(self, vector1, vector2): + if len(vector1) != len(vector2): + raise ValueError("To calculate consenus, both vectors should be of equal length") + + avg = (vector2 + vector1) / 2 + return avg + def _calculate_similarity(self, vector1, vector2): if len(vector1) != len(vector2): raise ValueError("To calculate similarity, both vectors should be of equal length") - diffs = (vector2 - vector1) + diffs = (vector1 - vector2) error_sum = (diffs ** 2).sum() mean_square_error = error_sum / len(vector1) return float(mean_square_error), diffs