mirror of https://github.com/explosion/spaCy.git
baseline performances: oracle KB, random and prior prob
This commit is contained in:
parent 24db1392b9
commit 6332af40de
@@ -5,11 +5,8 @@ import os
import re
import bz2
import datetime
from os import listdir

from examples.pipeline.wiki_entity_linking import run_el
from spacy.gold import GoldParse
from spacy.matcher import PhraseMatcher
from . import wikipedia_processor as wp, kb_creator

"""
@@ -17,7 +14,7 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
"""

# ENTITY_FILE = "gold_entities.csv"
-ENTITY_FILE = "gold_entities_100000.csv"  # use this file for faster processing
+ENTITY_FILE = "gold_entities_1000000.csv"  # use this file for faster processing


def create_training(entity_def_input, training_output):
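Note: the column layout of ENTITY_FILE is not spelled out in this hunk; judging from the field indices parsed in read_training below (fields[0] through fields[4]), a row presumably looks like this, with hypothetical values:

    article_id|alias|wp_title|start|end
    9696|Adams|Douglas Adams|21|26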
@@ -58,7 +55,6 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
            if cnt % 1000000 == 0:
                print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
            clean_line = line.strip().decode("utf-8")
-           # print(clean_line)

            if clean_line == "<revision>":
                reading_revision = True
@@ -121,7 +117,6 @@ text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')

def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
    found_entities = False
-   # print("Processing", article_id, article_title)

    # ignore meta Wikipedia pages
    if article_title.startswith("Wikipedia:"):
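As a quick illustration of the lookbehind/lookahead pair in the text_regex defined above, a self-contained sketch (the sample line is made up):

    import re

    text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')

    line = '<text xml:space="preserve">Douglas Adams was an author.</text>'
    match = text_regex.search(line)
    if match:
        print(match.group(0))  # Douglas Adams was an author.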
@@ -134,13 +129,8 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text
    if text.startswith("#REDIRECT"):
        return

-   # print()
-   # print(text)
-
    # get the raw text without markup etc, keeping only interwiki links
    clean_text = _get_clean_wp_text(text)
-   # print()
-   # print(clean_text)

    # read the text char by char to get the right offsets of the interwiki links
    final_text = ""
@@ -295,68 +285,62 @@ def is_dev(article_id):
    return article_id.endswith("3")


-def read_training_entities(training_output, dev, limit):
-    entityfile_loc = training_output + "/" + ENTITY_FILE
-    entries_per_article = dict()
-    article_ids = set()
-
-    with open(entityfile_loc, mode='r', encoding='utf8') as file:
-        for line in file:
-            if not limit or len(article_ids) < limit:
-                fields = line.replace('\n', "").split(sep='|')
-                article_id = fields[0]
-                if dev == is_dev(article_id) and article_id != "article_id":
-                    article_ids.add(article_id)
-
-                    alias = fields[1]
-                    wp_title = fields[2]
-                    start = fields[3]
-                    end = fields[4]
-
-                    entries_by_offset = entries_per_article.get(article_id, dict())
-                    entries_by_offset[start + "-" + end] = (alias, wp_title)
-
-                    entries_per_article[article_id] = entries_by_offset
-
-    return entries_per_article
-
-
def read_training(nlp, training_dir, dev, limit):
    # This method provides training examples that correspond to the entity annotations found by the nlp object
-    print("reading training entities")
-    entries_per_article = read_training_entities(training_output=training_dir, dev=dev, limit=limit)
-    print("done reading training entities")

+    entityfile_loc = training_dir + "/" + ENTITY_FILE
    data = []
-    for article_id, entries_by_offset in entries_per_article.items():
-        file_name = article_id + ".txt"
-        try:
-            # parse the article text
-            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as file:
-                text = file.read()
-                article_doc = nlp(text)

-                gold_entities = list()
-                for ent in article_doc.ents:
-                    start = ent.start_char
-                    end = ent.end_char
+    # we assume the data is written sequentially
+    current_article_id = None
+    current_doc = None
+    gold_entities = list()
+    ents_by_offset = dict()
+    skip_articles = set()
+    total_entities = 0

-                    entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
-                    if entity_tuple:
-                        alias, wp_title = entity_tuple
-                        if ent.text != alias:
-                            print("Non-matching entity in", article_id, start, end)
-                        else:
-                            gold_entities.append((start, end, wp_title))
+    with open(entityfile_loc, mode='r', encoding='utf8') as file:
+        for line in file:
+            if not limit or len(data) < limit:
+                if len(data) > 0 and len(data) % 50 == 0:
+                    print("Read", total_entities, "entities in", len(data), "articles")
+                fields = line.replace('\n', "").split(sep='|')
+                article_id = fields[0]
+                alias = fields[1]
+                wp_title = fields[2]
+                start = fields[3]
+                end = fields[4]

-                if gold_entities:
-                    gold = GoldParse(doc=article_doc, links=gold_entities)
-                    data.append((article_doc, gold))
+                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+                    if not current_doc or (current_article_id != article_id):
+                        # store the data from the previous article
+                        if gold_entities and current_doc:
+                            gold = GoldParse(doc=current_doc, links=gold_entities)
+                            data.append((current_doc, gold))
+                            total_entities += len(gold_entities)

-        except Exception as e:
-            print("Problem parsing article", article_id)
-            print(e)
-            raise e
+                        # parse the new article text
+                        file_name = article_id + ".txt"
+                        try:
+                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+                                text = f.read()
+                                current_doc = nlp(text)
+                                for ent in current_doc.ents:
+                                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent.text
+                        except Exception as e:
+                            print("Problem parsing article", article_id, e)

+                        current_article_id = article_id
+                        gold_entities = list()

+                    # repeat checking this condition in case an exception was thrown
+                    if current_doc and (current_article_id == article_id):
+                        found_ent = ents_by_offset.get(start + "_" + end, None)
+                        if found_ent:
+                            if found_ent != alias:
+                                skip_articles.add(current_article_id)
+                            else:
+                                gold_entities.append((int(start), int(end), wp_title))

+    print("Read", total_entities, "entities in", len(data), "articles")
    return data
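The (doc, GoldParse) pairs returned here are what STEP 6 below trains on; a minimal sketch of that consumption under spaCy v2's update API (the optimizer, batch size, and dropout are illustrative, not the script's actual settings):

    from spacy.util import minibatch

    # train_data: list of (Doc, GoldParse) tuples as returned by read_training
    optimizer = nlp.begin_training()
    for batch in minibatch(train_data, size=8):
        docs, golds = zip(*batch)
        losses = {}
        nlp.update(list(docs), list(golds), sgd=optimizer, drop=0.2, losses=losses)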
@@ -64,7 +64,8 @@ def run_pipeline():
    to_test_pipeline = True

    # write the NLP object, read back in and test again
-   test_nlp_io = False
+   to_write_nlp = True
+   to_read_nlp = True

    # STEP 1 : create prior probabilities from WP
    # run only once !
@@ -133,7 +134,7 @@ def run_pipeline():

    if train_pipe:
        print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-       train_limit = 10
+       train_limit = 5
        dev_limit = 2

        train_data = training_set_creator.read_training(nlp=nlp_2,
@@ -166,46 +167,42 @@ def run_pipeline():
                    )
                batchnr += 1
            except Exception as e:
-               print("Error updating batch", e)
+               print("Error updating batch:", e)
                raise(e)

-       losses['entity_linker'] = losses['entity_linker'] / batchnr
-       print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
+       if batchnr > 0:
+           losses['entity_linker'] = losses['entity_linker'] / batchnr
+           print("Epoch, train loss", itn, round(losses['entity_linker'], 2))

        dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                      training_dir=TRAINING_DIR,
                                                      dev=True,
                                                      limit=dev_limit)
-       print("Dev testing on", len(dev_data), "articles")

+       print()
+       print("Dev testing on", len(dev_data), "articles")

        if len(dev_data) and measure_performance:
            print()
            print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
            print()

+           acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label = _measure_baselines(dev_data, kb_2)
+           print("dev acc oracle:", round(acc_oracle, 3), [(x, round(y, 3)) for x, y in acc_oracle_by_label.items()])
+           print("dev acc random:", round(acc_random, 3), [(x, round(y, 3)) for x, y in acc_random_by_label.items()])
+           print("dev acc prior:", round(acc_prior, 3), [(x, round(y, 3)) for x, y in acc_prior_by_label.items()])

-           # print(" measuring accuracy 1-1")
            el_pipe.context_weight = 1
            el_pipe.prior_weight = 1
-           dev_acc_1_1, dev_acc_1_1_dict = _measure_accuracy(dev_data, el_pipe)
-           print("dev acc combo:", round(dev_acc_1_1, 3), [(x, round(y, 3)) for x, y in dev_acc_1_1_dict.items()])
-           train_acc_1_1, train_acc_1_1_dict = _measure_accuracy(train_data, el_pipe)
-           print("train acc combo:", round(train_acc_1_1, 3), [(x, round(y, 3)) for x, y in train_acc_1_1_dict.items()])

-           # baseline using only prior probabilities
-           el_pipe.context_weight = 0
-           el_pipe.prior_weight = 1
-           dev_acc_0_1, dev_acc_0_1_dict = _measure_accuracy(dev_data, el_pipe)
-           print("dev acc prior:", round(dev_acc_0_1, 3), [(x, round(y, 3)) for x, y in dev_acc_0_1_dict.items()])
-           train_acc_0_1, train_acc_0_1_dict = _measure_accuracy(train_data, el_pipe)
-           print("train acc prior:", round(train_acc_0_1, 3), [(x, round(y, 3)) for x, y in train_acc_0_1_dict.items()])
+           dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
+           print("dev acc combo:", round(dev_acc_combo, 3), [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])

            # using only context
            el_pipe.context_weight = 1
            el_pipe.prior_weight = 0
-           dev_acc_1_0, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
-           print("dev acc context:", round(dev_acc_1_0, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
-           train_acc_1_0, train_acc_1_0_dict = _measure_accuracy(train_data, el_pipe)
-           print("train acc context:", round(train_acc_1_0, 3), [(x, round(y, 3)) for x, y in train_acc_1_0_dict.items()])
+           dev_acc_context, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe)
+           print("dev acc context:", round(dev_acc_context, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()])
+           print()

            # reset for follow-up tests
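The weight settings above walk a small ablation grid by hand; the same sweep can be written as one loop (a sketch, assuming an el_pipe with the context_weight and prior_weight attributes used above and a measure function like _measure_accuracy):

    def sweep_weights(el_pipe, dev_data, measure):
        # try each (context_weight, prior_weight) combination once
        for name, cw, pw in [("combo", 1, 1), ("prior", 0, 1), ("context", 1, 0)]:
            el_pipe.context_weight = cw
            el_pipe.prior_weight = pw
            acc, acc_by_label = measure(dev_data, el_pipe)
            print("dev acc", name + ":", round(acc, 3), [(x, round(y, 3)) for x, y in acc_by_label.items()])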
@@ -219,7 +216,7 @@ def run_pipeline():
        run_el_toy_example(nlp=nlp_2)
        print()

-   if test_nlp_io:
+   if to_write_nlp:
        print()
        print("STEP 9: testing NLP IO", datetime.datetime.now())
        print()
@@ -229,9 +226,10 @@ def run_pipeline():
        print("reading from", NLP_2_DIR)
        nlp_3 = spacy.load(NLP_2_DIR)

-       print()
-       print("running toy example with NLP 2")
-       run_el_toy_example(nlp=nlp_3)
+       if to_read_nlp:
+           print()
+           print("running toy example with NLP 2")
+           run_el_toy_example(nlp=nlp_3)

    print()
    print("STOP", datetime.datetime.now())
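For reference, the round trip exercised by STEP 9 is spaCy v2's standard serialization (a sketch; nlp_2 and NLP_2_DIR come from the surrounding script):

    import spacy

    nlp_2.to_disk(NLP_2_DIR)       # write the pipeline, including the entity_linker pipe
    nlp_3 = spacy.load(NLP_2_DIR)  # read it back in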
@@ -270,6 +268,80 @@ def _measure_accuracy(data, el_pipe):
        except Exception as e:
            print("Error assessing accuracy", e)

+    acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
+    return acc, acc_by_label


+def _measure_baselines(data, kb):
+    random_correct_by_label = dict()
+    random_incorrect_by_label = dict()

+    oracle_correct_by_label = dict()
+    oracle_incorrect_by_label = dict()

+    prior_correct_by_label = dict()
+    prior_incorrect_by_label = dict()

+    docs = [d for d, g in data if len(d) > 0]
+    golds = [g for d, g in data if len(d) > 0]

+    for doc, gold in zip(docs, golds):
+        try:
+            correct_entries_per_article = dict()
+            for entity in gold.links:
+                start, end, gold_kb = entity
+                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb

+            for ent in doc.ents:
+                ent_label = ent.label_
+                start = ent.start_char
+                end = ent.end_char
+                gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)

+                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+                if gold_entity is not None:
+                    candidates = kb.get_candidates(ent.text)
+                    oracle_candidate = ""
+                    best_candidate = ""
+                    random_candidate = ""
+                    if candidates:
+                        scores = list()

+                        for c in candidates:
+                            scores.append(c.prior_prob)
+                            if c.entity_ == gold_entity:
+                                oracle_candidate = c.entity_

+                        best_index = scores.index(max(scores))
+                        best_candidate = candidates[best_index].entity_
+                        random_candidate = random.choice(candidates).entity_

+                    if gold_entity == best_candidate:
+                        prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1

+                    if gold_entity == random_candidate:
+                        random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1

+                    if gold_entity == oracle_candidate:
+                        oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
+                    else:
+                        oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1

+        except Exception as e:
+            print("Error assessing accuracy", e)

+    acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
+    acc_random, acc_random_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
+    acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)

+    return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label


+def calculate_acc(correct_by_label, incorrect_by_label):
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
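To make the three baselines concrete: for a mention with a known gold id, the oracle picks the gold candidate whenever the KB proposes it, the prior baseline picks the candidate with the highest prior probability, and the random baseline picks uniformly; calculate_acc (cut off by this hunk) then presumably reports correct / (correct + incorrect) per entity label. A standalone sketch with made-up candidates, using plain (entity, prior_prob) tuples instead of the KB's candidate objects:

    import random

    candidates = [("Q4426480", 0.6), ("Q3568763", 0.3), ("Q11573", 0.1)]  # hypothetical
    gold_entity = "Q3568763"

    oracle = next((e for e, p in candidates if e == gold_entity), "")  # right whenever gold is a candidate
    prior = max(candidates, key=lambda c: c[1])[0]                     # highest prior probability
    rand = random.choice(candidates)[0]                                # uniform pick

    for name, pick in [("oracle", oracle), ("prior", prior), ("random", rand)]:
        print(name, pick, "correct" if pick == gold_entity else "wrong")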
@@ -303,18 +375,25 @@ def run_el_toy_example(nlp):
           "The main character in Doug's novel is the man Arthur Dent, " \
           "but Douglas doesn't write about George Washington or Homer Simpson."
    doc = nlp(text)

    print(text)
    for ent in doc.ents:
        print("ent", ent.text, ent.label_, ent.kb_id_)

    print()

-    # Q4426480 is her husband, Q3568763 her tutor
-    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine."\
-           "Ada Lovelace loved her husband William King dearly. " \
-           "Ada Lovelace was tutored by her favorite physics tutor William King."
+    # Q4426480 is her husband
+    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
+           "She loved her husband William King dearly. "
    doc = nlp(text)
    print(text)
    for ent in doc.ents:
        print("ent", ent.text, ent.label_, ent.kb_id_)
    print()

+    # Q3568763 is her tutor
+    text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\
+           "She was tutored by her favorite physics tutor William King."
+    doc = nlp(text)
+    print(text)
+    for ent in doc.ents:
+        print("ent", ent.text, ent.label_, ent.kb_id_)