diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py
index ac8cae4a4..ea42f9ab6 100644
--- a/examples/pipeline/wiki_entity_linking/train_el.py
+++ b/examples/pipeline/wiki_entity_linking/train_el.py
@@ -11,11 +11,11 @@ from thinc.neural._classes.convolution import ExtractWindow
 
 from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
 
-from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic, Tok2Vec, cosine
+from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine
 
 from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
-from thinc.v2v import Model, Maxout, Affine, ReLu
-from thinc.t2v import Pooling, mean_pool, sum_pool
+from thinc.v2v import Model, Maxout, Affine
+from thinc.t2v import Pooling, mean_pool
 from thinc.t2t import ParametricAttention
 from thinc.misc import Residual
 from thinc.misc import LayerNorm as LN
@@ -30,24 +30,21 @@ from spacy.tokens import Doc
 
 class EL_Model:
 
     PRINT_INSPECT = False
-    PRINT_TRAIN = True
+    PRINT_BATCH_LOSS = False
     EPS = 0.0000000005
-    CUTOFF = 0.5
 
     BATCH_SIZE = 5
-    # UPSAMPLE = True
 
     DOC_CUTOFF = 300  # number of characters from the doc context
     INPUT_DIM = 300  # dimension of pre-trained vectors
 
     HIDDEN_1_WIDTH = 32
-    # HIDDEN_2_WIDTH = 32  # 6
     DESC_WIDTH = 64
-    ARTICLE_WIDTH = 64
+    ARTICLE_WIDTH = 128
     SENT_WIDTH = 64
 
     DROP = 0.1
-    LEARN_RATE = 0.0001
+    LEARN_RATE = 0.001
     EPOCHS = 10
     L2 = 1e-6
 
@@ -61,13 +58,10 @@ class EL_Model:
         self._build_cnn(embed_width=self.INPUT_DIM,
                         desc_width=self.DESC_WIDTH,
                         article_width=self.ARTICLE_WIDTH,
-                        sent_width=self.SENT_WIDTH, hidden_1_width=self.HIDDEN_1_WIDTH)
+                        sent_width=self.SENT_WIDTH,
+                        hidden_1_width=self.HIDDEN_1_WIDTH)
 
     def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
-        # raise errors instead of runtime warnings in case of int/float overflow
-        # (not sure if we need this. set L2 to 0 because it throws an error otherwsise)
-        # np.seterr(all='raise')
-        # alternative: np.seterr(divide="raise", over="warn", under="ignore", invalid="raise")
 
         train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
@@ -101,21 +95,6 @@ class EL_Model:
         train_pos_count = len(train_pos_entities)
         train_neg_count = len(train_neg_entities)
 
-        # if self.UPSAMPLE:
-        #     if to_print:
-        #         print()
-        #         print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
-        #
-        #     # upsample positives to 50-50 distribution
-        #     while train_pos_count < train_neg_count:
-        #         train_ent.append(random.choice(train_pos_entities))
-        #         train_pos_count += 1
-        #
-        #     # upsample negatives to 50-50 distribution
-        #     while train_neg_count < train_pos_count:
-        #         train_ent.append(random.choice(train_neg_entities))
-        #         train_neg_count += 1
-
         self._begin_training()
 
         if to_print:
@@ -126,24 +105,25 @@ class EL_Model:
             print("Dev test on", len(dev_clusters), "entity clusters in", len(dev_art_texts), "articles")
             print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
             print()
-            print(" CUTOFF", self.CUTOFF)
             print(" DOC_CUTOFF", self.DOC_CUTOFF)
             print(" INPUT_DIM", self.INPUT_DIM)
-            # print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
+            print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
             print(" DESC_WIDTH", self.DESC_WIDTH)
             print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH)
             print(" SENT_WIDTH", self.SENT_WIDTH)
-            # print(" HIDDEN_2_WIDTH", self.HIDDEN_2_WIDTH)
             print(" DROP", self.DROP)
             print(" LEARNING RATE", self.LEARN_RATE)
-            print(" UPSAMPLE", self.UPSAMPLE)
+            print(" BATCH SIZE", self.BATCH_SIZE)
             print()
 
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
-                       print_string="dev_random", calc_random=True)
+        dev_random = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
+                                    calc_random=True)
+        print("acc", "dev_random", round(dev_random, 2))
 
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
-                       print_string="dev_pre", avg=True)
+        dev_pre = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
+                                 avg=True)
+        print("acc", "dev_pre", round(dev_pre, 2))
+        print()
 
         processed = 0
         for i in range(self.EPOCHS):
@@ -163,45 +143,58 @@ class EL_Model:
                 start = start + self.BATCH_SIZE
                 stop = min(stop + self.BATCH_SIZE, len(train_clusters))
 
-            if self.PRINT_TRAIN:
-                print()
-                self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts,
-                               print_string="train_inter_epoch " + str(i), avg=True)
+            train_acc = self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts, avg=True)
+            dev_acc = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, avg=True)
 
-            self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
-                           print_string="dev_inter_epoch " + str(i), avg=True)
+            print(i, "acc train/dev", round(train_acc, 2), round(dev_acc, 2))
 
         if to_print:
             print()
             print("Trained on", processed, "entity clusters across", self.EPOCHS, "epochs")
 
-    def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts,
-                  print_string, avg=True, calc_random=False):
-
+    def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts, avg=True, calc_random=False):
         correct = 0
         incorrect = 0
 
-        for cluster, entities in entity_clusters.items():
-            correct_entities = [e for e in entities if golds[e]]
-            incorrect_entities = [e for e in entities if not golds[e]]
-            assert len(correct_entities) == 1
+        if calc_random:
+            for cluster, entities in entity_clusters.items():
+                correct_entities = [e for e in entities if golds[e]]
+                assert len(correct_entities) == 1
 
-            entities = list(entities)
-            shuffle(entities)
+                entities = list(entities)
+                shuffle(entities)
 
-            if calc_random:
-                predicted_entity = random.choice(entities)
-                if predicted_entity in correct_entities:
-                    correct += 1
-                else:
-                    incorrect += 1
+                # already inside the calc_random branch, so guess directly
+                predicted_entity = random.choice(entities)
+                if predicted_entity in correct_entities:
+                    correct += 1
+                else:
+                    incorrect += 1
+
+        else:
+            all_clusters = list()
+            arts_list = list()
+            sents_list = list()
+
+            for cluster in entity_clusters.keys():
+                all_clusters.append(cluster)
+                arts_list.append(art_texts[arts[cluster]])
+                sents_list.append(sent_texts[sents[cluster]])
+
+            art_docs = list(self.nlp.pipe(arts_list))
+            sent_docs = list(self.nlp.pipe(sents_list))
+
+            for i, cluster in enumerate(all_clusters):
+                entities = entity_clusters[cluster]
+                correct_entities = [e for e in entities if golds[e]]
+                assert len(correct_entities) == 1
+
+                entities = list(entities)
+                shuffle(entities)
 
-            else:
                 desc_docs = self.nlp.pipe([descs[e] for e in entities])
-                # article_texts = [art_texts[arts[e]] for e in entities]
-
-                sent_doc = self.nlp(sent_texts[sents[cluster]])
-                article_doc = self.nlp(art_texts[arts[cluster]])
+                sent_doc = sent_docs[i]
+                article_doc = art_docs[i]
 
                 predicted_index = self._predict(article_doc=article_doc, sent_doc=sent_doc,
                                                 desc_docs=desc_docs, avg=avg)
@@ -211,52 +204,56 @@ class EL_Model:
                 incorrect += 1
 
         if correct == incorrect == 0:
-            print("acc", print_string, "NA")
             return 0
 
         acc = correct / (correct + incorrect)
-        print("acc", print_string, round(acc, 2))
         return acc
 
     def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
+        # print()
+        # print("predicting article")
+
         if avg:
-            with self.article_encoder.use_params(self.sgd_article.averages) \
-                    and self.desc_encoder.use_params(self.sgd_desc.averages)\
-                    and self.sent_encoder.use_params(self.sgd_sent.averages):
-                # doc_encoding = self.article_encoder(article_doc)
+            # chaining context managers with `and` would only enter the last one,
+            # so list them with commas to actually apply all the averaged params
+            with self.article_encoder.use_params(self.sgd_article.averages), \
+                    self.desc_encoder.use_params(self.sgd_desc.averages), \
+                    self.sent_encoder.use_params(self.sgd_sent.averages), \
+                    self.cont_encoder.use_params(self.sgd_cont.averages):
                 desc_encodings = self.desc_encoder(desc_docs)
+                doc_encoding = self.article_encoder([article_doc])
                 sent_encoding = self.sent_encoder([sent_doc])
 
         else:
-            # doc_encodings = self.article_encoder(article_docs)
             desc_encodings = self.desc_encoder(desc_docs)
+            doc_encoding = self.article_encoder([article_doc])
             sent_encoding = self.sent_encoder([sent_doc])
 
-        sent_enc = np.transpose(sent_encoding)
+        # print("desc_encodings", desc_encodings)
+        # print("doc_encoding", doc_encoding)
+        # print("sent_encoding", sent_encoding)
+        concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
+        # print("concat_encoding", concat_encoding)
+
+        cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
+        # print("cont_encodings", cont_encodings)
+        context_enc = np.transpose(cont_encodings)
+        # print("context_enc", context_enc)
+
         highest_sim = -5
         best_i = -1
         for i, desc_enc in enumerate(desc_encodings):
-            sim = cosine(desc_enc, sent_enc)
+            sim = cosine(desc_enc, context_enc)
             if sim >= highest_sim:
                 best_i = i
                 highest_sim = sim
 
         return best_i
 
-    def _predict_random(self, entities, apply_threshold=True):
-        if not apply_threshold:
-            return [float(random.uniform(0, 1)) for _ in entities]
-        else:
-            return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for _ in entities]
-
     def _build_cnn(self, embed_width, desc_width, article_width, sent_width, hidden_1_width):
-        with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-            self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
-                                              end_width=desc_width)
-            self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
-                                                 end_width=article_width)
-            self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
-                                              end_width=sent_width)
+        self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width, end_width=desc_width)
+        self.cont_encoder = self._context_encoder(embed_width=embed_width, article_width=article_width,
+                                                  sent_width=sent_width, hidden_width=hidden_1_width,
+                                                  end_width=desc_width)
 
     # def _encoder(self, width):
     #     tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3,
@@ -264,12 +261,19 @@ class EL_Model:
     #
    #     return tok2vec >> flatten_add_lengths >> Pooling(mean_pool)
 
+    def _context_encoder(self, embed_width, article_width, sent_width, hidden_width, end_width):
+        self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=article_width)
+        self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=sent_width)
+
+        model = Affine(end_width, article_width + sent_width, drop_factor=0.0)
+        return model
+
     @staticmethod
     def _encoder(in_width, hidden_with, end_width):
         conv_depth = 2
         cnn_maxout_pieces = 3
 
-        with Model.define_operators({">>": chain}):
+        with Model.define_operators({">>": chain, "**": clone}):
             convolution = Residual((ExtractWindow(nW=1) >>
                                     LN(Maxout(hidden_with, hidden_with * 3, pieces=cnn_maxout_pieces))))
 
@@ -295,62 +299,75 @@ class EL_Model:
         self.sgd_sent.learn_rate = self.LEARN_RATE
         self.sgd_sent.L2 = self.L2
 
+        self.sgd_cont = create_default_optimizer(self.cont_encoder.ops)
+        self.sgd_cont.learn_rate = self.LEARN_RATE
+        self.sgd_cont.L2 = self.L2
+
         self.sgd_desc = create_default_optimizer(self.desc_encoder.ops)
         self.sgd_desc.learn_rate = self.LEARN_RATE
         self.sgd_desc.L2 = self.L2
 
-        # self.sgd = create_default_optimizer(self.model.ops)
-        # self.sgd.learn_rate = self.LEARN_RATE
-        # self.sgd.L2 = self.L2
-
     @staticmethod
     def get_loss(predictions, golds):
         loss, gradients = get_cossim_loss(predictions, golds)
         return loss, gradients
 
     def update(self, entity_clusters, golds, descs, art_texts, arts, sent_texts, sents):
+        all_clusters = list(entity_clusters.keys())
+
+        arts_list = list()
+        sents_list = list()
+        descs_list = list()
+
         for cluster, entities in entity_clusters.items():
-            correct_entities = [e for e in entities if golds[e]]
-            incorrect_entities = [e for e in entities if not golds[e]]
-
-            assert len(correct_entities) == 1
-
-            entities = list(entities)
-            shuffle(entities)
-
-            # article_text = art_texts[arts[cluster]]
-            cluster_sent = sent_texts[sents[cluster]]
-
-            # art_docs = self.nlp.pipe(article_text)
-            sent_doc = self.nlp(cluster_sent)
-
+            art = art_texts[arts[cluster]]
+            sent = sent_texts[sents[cluster]]
             for e in entities:
+                # TODO: more appropriate loss for the whole cluster (currently only pos entities)
                 if golds[e]:
-                    # TODO: more appropriate loss for the whole cluster (currently only pos entities)
-                    # TODO: speed up
-                    desc_doc = self.nlp(descs[e])
+                    arts_list.append(art)
+                    sents_list.append(sent)
+                    descs_list.append(descs[e])
 
-                    # doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
-                    sent_encodings, bp_sent = self.sent_encoder.begin_update([sent_doc], drop=self.DROP)
-                    desc_encodings, bp_desc = self.desc_encoder.begin_update([desc_doc], drop=self.DROP)
+        desc_docs = self.nlp.pipe(descs_list)
+        desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
 
-                    sent_encoding = sent_encodings[0]
-                    desc_encoding = desc_encodings[0]
+        art_docs = self.nlp.pipe(arts_list)
+        sent_docs = self.nlp.pipe(sents_list)
 
-                    sent_enc = self.sent_encoder.ops.asarray([sent_encoding])
-                    desc_enc = self.sent_encoder.ops.asarray([desc_encoding])
+        doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
+        sent_encodings, bp_sent = self.sent_encoder.begin_update(sent_docs, drop=self.DROP)
 
-                    # print("sent_encoding", type(sent_encoding), sent_encoding)
-                    # print("desc_encoding", type(desc_encoding), desc_encoding)
-                    # print("getting los for entity", e)
+        concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
+                            range(len(all_clusters))]
+        cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
 
-                    loss, gradient = self.get_loss(sent_enc, desc_enc)
+        # print("sent_encodings", type(sent_encodings), sent_encodings)
+        # print("desc_encodings", type(desc_encodings), desc_encodings)
+        # print("doc_encodings", type(doc_encodings), doc_encodings)
+        # print("getting loss for", len(arts_list), "entities")
 
-                    # print("gradient", gradient)
-                    # print("loss", loss)
+        loss, gradient = self.get_loss(cont_encodings, desc_encodings)
 
-                    bp_sent(gradient, sgd=self.sgd_sent)
-                    # bp_desc(desc_gradients, sgd=self.sgd_desc)  TODO
-                    # print()
+        # print("gradient", gradient)
+        if self.PRINT_BATCH_LOSS:
+            print("batch loss", loss)
+
+        context_gradient = bp_cont(gradient, sgd=self.sgd_cont)
+
+        # gradient: concat (doc+sent) vs. desc
+        sent_start = self.ARTICLE_WIDTH
+        sent_gradients = list()
+        doc_gradients = list()
+        for x in context_gradient:
+            doc_gradients.append(list(x[0:sent_start]))
+            sent_gradients.append(list(x[sent_start:]))
+
+        # print("doc_gradients", doc_gradients)
+        # print("sent_gradients", sent_gradients)
+
+        bp_doc(doc_gradients, sgd=self.sgd_article)
+        bp_sent(sent_gradients, sgd=self.sgd_sent)
 
     def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
         id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index a24ff30c5..25c1e4721 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -111,7 +111,7 @@ if __name__ == "__main__":
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=5000, devlimit=100)
         print()
 
         # STEP 7: apply the EL algorithm on the dev dataset
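
Side note on the update() bookkeeping above: the article and sentence encodings are
concatenated into one context vector per cluster before they go through the Affine
cont_encoder, so the gradient returned by bp_cont has to be sliced apart again at
ARTICLE_WIDTH before it can be routed into bp_doc and bp_sent. A minimal NumPy sketch
of that round trip (standalone illustration with a made-up batch size, not the
trainer's actual code):

    import numpy as np

    ARTICLE_WIDTH = 128  # same widths as the constants above
    SENT_WIDTH = 64
    n_clusters = 3       # made-up batch size

    doc_encodings = np.random.rand(n_clusters, ARTICLE_WIDTH)
    sent_encodings = np.random.rand(n_clusters, SENT_WIDTH)

    # forward: one concatenated context vector per cluster, as fed to cont_encoder
    concat = np.hstack([doc_encodings, sent_encodings])  # shape (3, 192)

    # backward: a gradient w.r.t. the concatenation splits at ARTICLE_WIDTH,
    # mirroring the x[0:sent_start] / x[sent_start:] slicing in update()
    d_concat = np.random.rand(*concat.shape)
    d_doc, d_sent = d_concat[:, :ARTICLE_WIDTH], d_concat[:, ARTICLE_WIDTH:]

    assert d_doc.shape == doc_encodings.shape
    assert d_sent.shape == sent_encodings.shape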
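get_loss() delegates to get_cossim_loss from spacy._ml, which pushes each context
encoding towards the encoding of its gold description. As a rough sketch of how a
cosine-similarity loss and its gradient can be computed for a single vector pair
(an illustration of the idea, not necessarily the spacy._ml implementation):

    import numpy as np

    def cossim_loss(pred, gold):
        # loss = 1 - cos(pred, gold): zero when the vectors point the same way
        norm_p = np.linalg.norm(pred)
        norm_g = np.linalg.norm(gold)
        sim = pred.dot(gold) / (norm_p * norm_g)
        # gradient of the loss w.r.t. pred, from differentiating the cosine
        d_pred = -(gold / (norm_p * norm_g) - sim * pred / norm_p ** 2)
        return 1.0 - sim, d_pred

    loss, grad = cossim_loss(np.array([1.0, 0.0]), np.array([1.0, 1.0]))
    # loss ~= 0.29; stepping pred along -grad rotates it towards gold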