From 9e88763dab895d7ee86a21d78c0e2c950e8d6850 Mon Sep 17 00:00:00 2001
From: svlandeg
Date: Mon, 3 Jun 2019 08:04:49 +0200
Subject: [PATCH] 60% acc run

---
 .../pipeline/wiki_entity_linking/train_el.py       | 159 ++++++++----------
 .../wiki_entity_linking/wiki_nel_pipeline.py       |   3 +-
 2 files changed, 74 insertions(+), 88 deletions(-)

diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py
index ba8a6a6c9..a2db2dc95 100644
--- a/examples/pipeline/wiki_entity_linking/train_el.py
+++ b/examples/pipeline/wiki_entity_linking/train_el.py
@@ -23,7 +23,6 @@
 from thinc.misc import LayerNorm as LN
 
 # from spacy.cli.pretrain import get_cossim_loss
 from spacy.matcher import PhraseMatcher
-from spacy.tokens import Doc
 
 """ TODO: this code needs to be implemented in pipes.pyx"""
@@ -46,7 +45,7 @@ class EL_Model:
 
     DROP = 0.1
     LEARN_RATE = 0.001
-    EPOCHS = 10
+    EPOCHS = 20
     L2 = 1e-6
 
     name = "entity_linker"
@@ -211,9 +210,6 @@ class EL_Model:
         return acc
 
     def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
-        # print()
-        # print("predicting article")
-
         if avg:
             with self.article_encoder.use_params(self.sgd_article.averages) \
                  and self.desc_encoder.use_params(self.sgd_desc.averages)\
@@ -228,16 +224,10 @@ class EL_Model:
             doc_encoding = self.article_encoder([article_doc])
             sent_encoding = self.sent_encoder([sent_doc])
 
-        # print("desc_encodings", desc_encodings)
-        # print("doc_encoding", doc_encoding)
-        # print("sent_encoding", sent_encoding)
         concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
-        # print("concat_encoding", concat_encoding)
 
         cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
-        # print("cont_encodings", cont_encodings)
 
         context_enc = np.transpose(cont_encodings)
-        # print("context_enc", context_enc)
         highest_sim = -5
         best_i = -1
@@ -353,11 +343,11 @@ class EL_Model:
                     sents_list.append(sent)
                     descs_list.append(descs[e])
                     targets.append([1])
-                else:
-                    arts_list.append(art)
-                    sents_list.append(sent)
-                    descs_list.append(descs[e])
-                    targets.append([-1])
+                # else:
+                #     arts_list.append(art)
+                #     sents_list.append(sent)
+                #     descs_list.append(descs[e])
+                #     targets.append([-1])
 
         desc_docs = self.nlp.pipe(descs_list)
         desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
@@ -372,18 +362,17 @@ class EL_Model:
                                       range(len(targets))]
         cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
 
-        # print("sent_encodings", type(sent_encodings), sent_encodings)
-        # print("desc_encodings", type(desc_encodings), desc_encodings)
-        # print("doc_encodings", type(doc_encodings), doc_encodings)
-        # print("getting los for", len(arts_list), "entities")
+        loss, cont_gradient = self.get_loss(cont_encodings, desc_encodings, targets)
 
-        loss, gradient = self.get_loss(cont_encodings, desc_encodings, targets)
+        # loss, desc_gradient = self.get_loss(desc_encodings, cont_encodings, targets)
+        # cont_gradient = cont_gradient / 2
+        # desc_gradient = desc_gradient / 2
+        # bp_desc(desc_gradient, sgd=self.sgd_desc)
 
-        # print("gradient", gradient)
         if self.PRINT_BATCH_LOSS:
             print("batch loss", loss)
 
-        context_gradient = bp_cont(gradient, sgd=self.sgd_cont)
+        context_gradient = bp_cont(cont_gradient, sgd=self.sgd_cont)
 
         # gradient : concat (doc+sent) vs. desc
         sent_start = self.ARTICLE_WIDTH
@@ -393,9 +382,6 @@ class EL_Model:
             doc_gradients.append(list(x[0:sent_start]))
             sent_gradients.append(list(x[sent_start:]))
 
-        # print("doc_gradients", doc_gradients)
-        # print("sent_gradients", sent_gradients)
-
         bp_doc(doc_gradients, sgd=self.sgd_article)
         bp_sent(sent_gradients, sgd=self.sgd_sent)
 
@@ -426,74 +412,75 @@ class EL_Model:
                     article_id = f.replace(".txt", "")
                     if cnt % 500 == 0 and to_print:
                         print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
-                    cnt += 1
-                    # parse the article text
-                    with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
-                        text = file.read()
-                        article_doc = self.nlp(text)
-                        truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
-                        text_by_article[article_id] = truncated_text
+                    try:
+                        # parse the article text
+                        with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
+                            text = file.read()
+                            article_doc = self.nlp(text)
+                            truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
+                            text_by_article[article_id] = truncated_text
 
-                    # process all positive and negative entities, collect all relevant mentions in this article
-                    for mention, entity_pos in correct_entries[article_id].items():
-                        cluster = article_id + "_" + mention
-                        descr = id_to_descr.get(entity_pos)
-                        entities = set()
-                        if descr:
-                            entity = "E_" + str(next_entity_nr) + "_" + cluster
-                            next_entity_nr += 1
-                            gold_by_entity[entity] = 1
-                            desc_by_entity[entity] = descr
-                            entities.add(entity)
+                        # process all positive and negative entities, collect all relevant mentions in this article
+                        for mention, entity_pos in correct_entries[article_id].items():
+                            cluster = article_id + "_" + mention
+                            descr = id_to_descr.get(entity_pos)
+                            entities = set()
+                            if descr:
+                                entity = "E_" + str(next_entity_nr) + "_" + cluster
+                                next_entity_nr += 1
+                                gold_by_entity[entity] = 1
+                                desc_by_entity[entity] = descr
+                                entities.add(entity)
 
-                        entity_negs = incorrect_entries[article_id][mention]
-                        for entity_neg in entity_negs:
-                            descr = id_to_descr.get(entity_neg)
-                            if descr:
-                                entity = "E_" + str(next_entity_nr) + "_" + cluster
-                                next_entity_nr += 1
-                                gold_by_entity[entity] = 0
-                                desc_by_entity[entity] = descr
-                                entities.add(entity)
+                            entity_negs = incorrect_entries[article_id][mention]
+                            for entity_neg in entity_negs:
+                                descr = id_to_descr.get(entity_neg)
+                                if descr:
+                                    entity = "E_" + str(next_entity_nr) + "_" + cluster
+                                    next_entity_nr += 1
+                                    gold_by_entity[entity] = 0
+                                    desc_by_entity[entity] = descr
+                                    entities.add(entity)
 
-                        found_matches = 0
-                        if len(entities) > 1:
-                            entities_by_cluster[cluster] = entities
+                            found_matches = 0
+                            if len(entities) > 1:
+                                entities_by_cluster[cluster] = entities
 
-                            # find all matches in the doc for the mentions
-                            # TODO: fix this - doesn't look like all entities are found
-                            matcher = PhraseMatcher(self.nlp.vocab)
-                            patterns = list(self.nlp.tokenizer.pipe([mention]))
+                                # find all matches in the doc for the mentions
+                                # TODO: fix this - doesn't look like all entities are found
+                                matcher = PhraseMatcher(self.nlp.vocab)
+                                patterns = list(self.nlp.tokenizer.pipe([mention]))
 
-                            matcher.add("TerminologyList", None, *patterns)
-                            matches = matcher(article_doc)
+                                matcher.add("TerminologyList", None, *patterns)
+                                matches = matcher(article_doc)
+                                # store sentences
+                                for match_id, start, end in matches:
+                                    span = article_doc[start:end]
+                                    if mention == span.text:
+                                        found_matches += 1
+                                        sent_text = span.sent.text
+                                        sent_nr = sentence_by_text.get(sent_text, None)
+                                        if sent_nr is None:
+                                            sent_nr = "S_" + str(next_sent_nr) + article_id
+                                            next_sent_nr += 1
+                                            text_by_sentence[sent_nr] = sent_text
+                                            sentence_by_text[sent_text] = sent_nr
+                                        article_by_cluster[cluster] = article_id
+                                        sentence_by_cluster[cluster] = sent_nr
 
-                            # store sentences
-                            for match_id, start, end in matches:
-                                found_matches += 1
-                                span = article_doc[start:end]
-                                assert mention == span.text
-                                sent_text = span.sent.text
-                                sent_nr = sentence_by_text.get(sent_text, None)
-                                if sent_nr is None:
-                                    sent_nr = "S_" + str(next_sent_nr) + article_id
-                                    next_sent_nr += 1
-                                    text_by_sentence[sent_nr] = sent_text
-                                    sentence_by_text[sent_text] = sent_nr
-                                article_by_cluster[cluster] = article_id
-                                sentence_by_cluster[cluster] = sent_nr
-
-                        if found_matches == 0:
-                            # TODO print("Could not find neg instances or sentence matches for", mention, "in", article_id)
-                            entities_by_cluster.pop(cluster, None)
-                            article_by_cluster.pop(cluster, None)
-                            sentence_by_cluster.pop(cluster, None)
-                            for entity in entities:
-                                gold_by_entity.pop(entity, None)
-                                desc_by_entity.pop(entity, None)
-
+                            if found_matches == 0:
+                                # print("Could not find neg instances or sentence matches for", mention, "in", article_id)
+                                entities_by_cluster.pop(cluster, None)
+                                article_by_cluster.pop(cluster, None)
+                                sentence_by_cluster.pop(cluster, None)
+                                for entity in entities:
+                                    gold_by_entity.pop(entity, None)
+                                    desc_by_entity.pop(entity, None)
+                        cnt += 1
+                    except:
+                        print("Problem parsing article", article_id)
 
         if to_print:
             print()
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index a24ff30c5..2ebf9973e 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -111,7 +111,7 @@ if __name__ == "__main__":
         print("STEP 6: training", datetime.datetime.now())
         my_nlp = spacy.load('en_core_web_md')
         trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100)
+        trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
         print()
 
     # STEP 7: apply the EL algorithm on the dev dataset
@@ -120,7 +120,6 @@ if __name__ == "__main__":
         run_el.run_el_dev(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, limit=2000)
         print()
 
-    # TODO coreference resolution
     # add_coref()