diff --git a/examples/pipeline/wiki_entity_linking/training_set_creator.py b/examples/pipeline/wiki_entity_linking/training_set_creator.py
index 5d089c620..4ce69e75d 100644
--- a/examples/pipeline/wiki_entity_linking/training_set_creator.py
+++ b/examples/pipeline/wiki_entity_linking/training_set_creator.py
@@ -5,11 +5,8 @@ import os
 import re
 import bz2
 import datetime
-from os import listdir
-from examples.pipeline.wiki_entity_linking import run_el
 from spacy.gold import GoldParse
-from spacy.matcher import PhraseMatcher
 
 from . import wikipedia_processor as wp, kb_creator
 
 """
@@ -17,7 +14,7 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
 """
 
 # ENTITY_FILE = "gold_entities.csv"
-ENTITY_FILE = "gold_entities_100000.csv"  # use this file for faster processing
+ENTITY_FILE = "gold_entities_1000000.csv"  # use this file for faster processing
 
 
 def create_training(entity_def_input, training_output):
@@ -58,7 +55,6 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
             if cnt % 1000000 == 0:
                 print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
             clean_line = line.strip().decode("utf-8")
-            # print(clean_line)
 
             if clean_line == "":
                 reading_revision = True
@@ -121,7 +117,6 @@ text_regex = re.compile(r'(?<=<text xml:space="preserve">).*(?=</text)')
@@ ... @@ def read_training(nlp, training_dir, dev, limit):
+                if len(data) > 0 and len(data) % 50 == 0:
+                    print("Read", total_entities, "entities in", len(data), "articles")
+                fields = line.replace('\n', "").split(sep='|')
+                article_id = fields[0]
+                alias = fields[1]
+                wp_title = fields[2]
+                start = fields[3]
+                end = fields[4]
 
-            if gold_entities:
-                gold = GoldParse(doc=article_doc, links=gold_entities)
-                data.append((article_doc, gold))
+                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+                    if not current_doc or (current_article_id != article_id):
+                        # store the data from the previous article
+                        if gold_entities and current_doc:
+                            gold = GoldParse(doc=current_doc, links=gold_entities)
+                            data.append((current_doc, gold))
+                            total_entities += len(gold_entities)
 
-        except Exception as e:
-            print("Problem parsing article", article_id)
-            print(e)
-            raise e
+                        # parse the new article text
+                        file_name = article_id + ".txt"
+                        try:
+                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+                                text = f.read()
+                                current_doc = nlp(text)
+                                for ent in current_doc.ents:
+                                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
+                        except Exception as e:
+                            print("Problem parsing article", article_id, e)
+                        current_article_id = article_id
+                        gold_entities = list()
+
+                    # repeat checking this condition in case an exception was thrown
+                    if current_doc and (current_article_id == article_id):
+                        found_ent = ents_by_offset.get(start + "_" + end, None)
+                        if found_ent:
+                            if found_ent != alias:
+                                skip_articles.add(current_article_id)
+                            else:
+                                gold_entities.append((int(start), int(end), wp_title))
+
+    print("Read", total_entities, "entities in", len(data), "articles")
     return data
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index b3b3479e2..7b54df527 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -64,7 +64,8 @@ def run_pipeline():
     to_test_pipeline = True
 
     # write the NLP object, read back in and test again
-    test_nlp_io = False
+    to_write_nlp = True
+    to_read_nlp = True
 
     # STEP 1 : create prior probabilities from WP
     # run only once !
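A note on the read_training rewrite above: the gold entities file is assumed to be written out article by article, so the loop can keep reusing the current article's Doc and only re-parse when the article id changes. The standalone sketch below illustrates just that grouping step, assuming a pipe-delimited file with a header row of article_id|alias|wp_title|start|end; the group_gold_by_article helper is hypothetical and not part of the patch.

# Hypothetical helper, not part of the patch: groups the pipe-delimited
# gold entities file by article id, mirroring the layout read_training expects.
from collections import defaultdict

def group_gold_by_article(entityfile_loc):
    gold_by_article = defaultdict(list)
    with open(entityfile_loc, mode="r", encoding="utf8") as file:
        for line in file:
            fields = line.rstrip("\n").split("|")
            article_id, alias, wp_title, start, end = fields[:5]
            if article_id != "article_id":  # skip the header row
                gold_by_article[article_id].append((int(start), int(end), wp_title))
    return gold_by_article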
@@ -133,7 +134,7 @@ def run_pipeline():
 
     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 10
+        train_limit = 5
         dev_limit = 2
 
         train_data = training_set_creator.read_training(nlp=nlp_2,
@@ -166,46 +167,42 @@ def run_pipeline():
                     )
                     batchnr += 1
                 except Exception as e:
-                    print("Error updating batch", e)
+                    print("Error updating batch:", e)
+                    raise e
 
-            losses['entity_linker'] = losses['entity_linker'] / batchnr
-            print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
+            if batchnr > 0:
+                losses['entity_linker'] = losses['entity_linker'] / batchnr
+                print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
 
         dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                       training_dir=TRAINING_DIR,
                                                       dev=True,
                                                       limit=dev_limit)
-        print("Dev testing on", len(dev_data), "articles")
+        print()
+        print("Dev testing on", len(dev_data), "articles")
 
         if len(dev_data) and measure_performance:
             print()
             print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
             print()
+            acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label = _measure_baselines(dev_data, kb_2)
+            print("dev acc oracle:", round(acc_oracle, 3), [(x, round(y, 3)) for x, y in acc_oracle_by_label.items()])
+            print("dev acc random:", round(acc_random, 3), [(x, round(y, 3)) for x, y in acc_random_by_label.items()])
+            print("dev acc prior:", round(acc_prior, 3), [(x, round(y, 3)) for x, y in acc_prior_by_label.items()])
+
             # print(" measuring accuracy 1-1")
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 1
-            dev_acc_1_1, dev_acc_1_1_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc combo:", round(dev_acc_1_1, 3), [(x, round(y, 3)) for x, y in dev_acc_1_1_dict.items()])
-            train_acc_1_1, train_acc_1_1_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc combo:", round(train_acc_1_1, 3), [(x, round(y, 3)) for x, y in train_acc_1_1_dict.items()])
-
-            # baseline using only prior probabilities
-            el_pipe.context_weight = 0
-            el_pipe.prior_weight = 1
-            dev_acc_0_1, dev_acc_0_1_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc prior:", round(dev_acc_0_1, 3), [(x, round(y, 3)) for x, y in dev_acc_0_1_dict.items()])
-            train_acc_0_1, train_acc_0_1_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc prior:", round(train_acc_0_1, 3), [(x, round(y, 3)) for x, y in train_acc_0_1_dict.items()])
+            dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc combo:", round(dev_acc_combo, 3), [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
 
             # using only context
             el_pipe.context_weight = 1
             el_pipe.prior_weight = 0
-            dev_acc_1_0, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
-            print("dev acc context:", round(dev_acc_1_0, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
-            train_acc_1_0, train_acc_1_0_dict = _measure_accuracy(train_data, el_pipe)
-            print("train acc context:", round(train_acc_1_0, 3), [(x, round(y, 3)) for x, y in train_acc_1_0_dict.items()])
+            dev_acc_context, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
+            print("dev acc context:", round(dev_acc_context, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
 
             print()
             # reset for follow-up tests
@@ -219,7 +216,7 @@ def run_pipeline():
         run_el_toy_example(nlp=nlp_2)
         print()
 
-    if test_nlp_io:
+    if to_write_nlp:
         print()
         print("STEP 9: testing NLP IO", datetime.datetime.now())
         print()
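Between this hunk and the next, the elided lines presumably write nlp_2 out to NLP_2_DIR before it is read back below. The round trip uses spaCy's standard serialization API; a minimal sketch, with NLP_2_DIR standing in for any output directory:

# Minimal serialization round trip with spaCy's standard API; NLP_2_DIR
# is a placeholder path. The entity_linker pipe is saved and restored
# together with the rest of the pipeline.
import spacy

def write_and_read_back(nlp_2, NLP_2_DIR):
    nlp_2.to_disk(NLP_2_DIR)       # write the whole pipeline out
    nlp_3 = spacy.load(NLP_2_DIR)  # read it back as a fresh object
    return nlp_3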
@@ -229,9 +226,10 @@ def run_pipeline():
 
         print("reading from", NLP_2_DIR)
         nlp_3 = spacy.load(NLP_2_DIR)
 
-        print()
-        print("running toy example with NLP 2")
-        run_el_toy_example(nlp=nlp_3)
+        if to_read_nlp:
+            print()
+            print("running toy example with NLP 2")
+            run_el_toy_example(nlp=nlp_3)
 
     print()
     print("STOP", datetime.datetime.now())
@@ -270,6 +268,80 @@ def _measure_accuracy(data, el_pipe):
         except Exception as e:
             print("Error assessing accuracy", e)
 
+    acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
+    return acc, acc_by_label
+
+
+def _measure_baselines(data, kb):
+    random_correct_by_label = dict()
+    random_incorrect_by_label = dict()
+
+    oracle_correct_by_label = dict()
+    oracle_incorrect_by_label = dict()
+
+    prior_correct_by_label = dict()
+    prior_incorrect_by_label = dict()
+
+    docs = [d for d, g in data if len(d) > 0]
+    golds = [g for d, g in data if len(d) > 0]
+
+    for doc, gold in zip(docs, golds):
+        try:
+            correct_entries_per_article = dict()
+            for entity in gold.links:
+                start, end, gold_kb = entity
+                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+
+            for ent in doc.ents:
+                ent_label = ent.label_
+                start = ent.start_char
+                end = ent.end_char
+                gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+
+                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+                if gold_entity is not None:
+                    candidates = kb.get_candidates(ent.text)
+                    oracle_candidate = ""
+                    best_candidate = ""
+                    random_candidate = ""
+                    if candidates:
+                        scores = list()
+
+                        for c in candidates:
+                            scores.append(c.prior_prob)
+                            if c.entity_ == gold_entity:
+                                oracle_candidate = c.entity_
+
+                        best_index = scores.index(max(scores))
+                        best_candidate = candidates[best_index].entity_
+                        random_candidate = random.choice(candidates).entity_
+
+                    if gold_entity == best_candidate:
+                        prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
+
+                    if gold_entity == random_candidate:
+                        random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
+
+                    if gold_entity == oracle_candidate:
+                        oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
+
+        except Exception as e:
+            print("Error assessing accuracy", e)
+
+    acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
+    acc_random, acc_random_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
+    acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
+
+    return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
+
+
+def calculate_acc(correct_by_label, incorrect_by_label):
     acc_by_label = dict()
     total_correct = 0
     total_incorrect = 0
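The three baselines measured by _measure_baselines above bracket the trained model: random selection gives a lower bound, picking the highest prior probability is the no-context baseline, and the oracle (which picks the gold entity whenever the candidate generator proposes it at all) is an upper bound on any ranker over the same candidates. A condensed per-mention sketch follows; note the function relies on the random module, whose import at the top of wiki_nel_pipeline.py is not visible in these hunks.

# Condensed per-mention version of the baseline logic above. `candidates`
# is the list returned by kb.get_candidates(ent.text); each candidate
# carries .entity_ and .prior_prob. The baseline_picks helper is a sketch,
# not part of the patch.
import random

def baseline_picks(candidates, gold_entity):
    if not candidates:
        return "", "", ""
    random_pick = random.choice(candidates).entity_                   # lower bound
    prior_pick = max(candidates, key=lambda c: c.prior_prob).entity_  # prior-only
    oracle_pick = gold_entity if any(c.entity_ == gold_entity for c in candidates) else ""  # upper bound
    return random_pick, prior_pick, oracle_pick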
" \ - "Ada Lovelace was tutored by her favorite physics tutor William King." + # Q4426480 is her husband + text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\ + "She loved her husband William King dearly. " doc = nlp(text) + print(text) + for ent in doc.ents: + print("ent", ent.text, ent.label_, ent.kb_id_) + print() + # Q3568763 is her tutor + text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\ + "She was tutored by her favorite physics tutor William King." + doc = nlp(text) + print(text) for ent in doc.ents: print("ent", ent.text, ent.label_, ent.kb_id_)