Changes to wiki_entity_linker (#4235)

* Changes to wiki_entity_linker

* No more f-strings

* Make some requested changes

* Add back option to get descriptions from wd not wp

* Fix logs

* Address comments and clean evaluation

* Remove type hints

* Refactor evaluation, add back metrics by label

* Address comments

* Log training performance as well as dev
This commit is contained in:
Euan Dowers 2019-09-13 17:03:57 +02:00 committed by Ines Montani
parent 2ae5db580e
commit a6830d60e8
9 changed files with 709 additions and 771 deletions

View File

@ -0,0 +1,11 @@
TRAINING_DATA_FILE = "gold_entities.jsonl"
KB_FILE = "kb"
KB_MODEL_DIR = "nlp_kb"
OUTPUT_MODEL_DIR = "nlp"
PRIOR_PROB_PATH = "prior_prob.csv"
ENTITY_DEFS_PATH = "entity_defs.csv"
ENTITY_FREQ_PATH = "entity_freq.csv"
ENTITY_DESCR_PATH = "entity_descriptions.csv"
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'

View File

@ -0,0 +1,200 @@
import logging
import random
from collections import defaultdict
logger = logging.getLogger(__name__)
class Metrics(object):
true_pos = 0
false_pos = 0
false_neg = 0
def update_results(self, true_entity, candidate):
candidate_is_correct = true_entity == candidate
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
# Therefore, if candidate_is_correct then we have a true positive and never a true negative
self.true_pos += candidate_is_correct
self.false_neg += not candidate_is_correct
if candidate not in {"", "NIL"}:
self.false_pos += not candidate_is_correct
def calculate_precision(self):
if self.true_pos == 0:
return 0.0
else:
return self.true_pos / (self.true_pos + self.false_pos)
def calculate_recall(self):
if self.true_pos == 0:
return 0.0
else:
return self.true_pos / (self.true_pos + self.false_neg)
class EvaluationResults(object):
def __init__(self):
self.metrics = Metrics()
self.metrics_by_label = defaultdict(Metrics)
def update_metrics(self, ent_label, true_entity, candidate):
self.metrics.update_results(true_entity, candidate)
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
def increment_false_negatives(self):
self.metrics.false_neg += 1
def report_metrics(self, model_name):
model_str = model_name.title()
recall = self.metrics.calculate_recall()
precision = self.metrics.calculate_precision()
return ("{}: ".format(model_str) +
"Recall = {} | ".format(round(recall, 3)) +
"Precision = {} | ".format(round(precision, 3)) +
"Precision by label = {}".format({k: v.calculate_precision()
for k, v in self.metrics_by_label.items()}))
class BaselineResults(object):
def __init__(self):
self.random = EvaluationResults()
self.prior = EvaluationResults()
self.oracle = EvaluationResults()
def report_accuracy(self, model):
results = getattr(self, model)
return results.report_metrics(model)
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
self.random.update_metrics(ent_label, true_entity, random_candidate)
def measure_performance(dev_data, kb, el_pipe):
baseline_accuracies = measure_baselines(
dev_data, kb
)
logger.info(baseline_accuracies.report_accuracy("random"))
logger.info(baseline_accuracies.report_accuracy("prior"))
logger.info(baseline_accuracies.report_accuracy("oracle"))
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only"))
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context and prior"))
def get_eval_results(data, el_pipe=None):
# If the docs in the data require further processing with an entity linker, set el_pipe
from tqdm import tqdm
docs = []
golds = []
for d, g in tqdm(data, leave=False):
if len(d) > 0:
golds.append(g)
if el_pipe is not None:
docs.append(el_pipe(d))
else:
docs.append(d)
results = EvaluationResults()
for doc, gold in zip(docs, golds):
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
results.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
results.update_metrics(ent_label, gold_entity, pred_entity)
except Exception as e:
logging.error("Error assessing accuracy " + str(e))
return results
def measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_d = dict()
baseline_results = BaselineResults()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
correct_entries_per_article = dict()
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
baseline_results.random.increment_false_negatives()
baseline_results.oracle.increment_false_negatives()
baseline_results.prior.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
random_candidate = ""
if candidates:
scores = []
for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
best_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
baseline_results.update_baselines(gold_entity, ent_label,
random_candidate, best_candidate, oracle_candidate)
return baseline_results
def _offset(start, end):
return "{}_{}".format(start, end)

View File

@ -1,12 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
import csv
import logging
import spacy
import sys
from spacy.kb import KnowledgeBase
import csv
import datetime
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
csv.field_size_limit(sys.maxsize)
logger = logging.getLogger(__name__)
def create_kb(
@ -14,52 +22,73 @@ def create_kb(
max_entities_per_alias,
min_entity_freq,
min_occ,
entity_def_output,
entity_descr_output,
entity_def_input,
entity_descr_path,
count_input,
prior_prob_input,
wikidata_input,
entity_vector_length,
limit=None,
read_raw_data=True,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_input)
id_to_descr = get_id_to_description(entity_descr_path)
# check the length of the nlp vectors
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
input_dim = nlp.vocab.vectors_length
print("Loaded pre-trained vectors of size %s" % input_dim)
logger.info("Loaded pre-trained vectors of size %s" % input_dim)
else:
raise ValueError(
"The `nlp` object should have access to pre-trained word vectors, "
" cf. https://spacy.io/usage/models#languages."
)
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
if read_raw_data:
print()
print(now(), " * read wikidata entities:")
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
wikidata_input, limit=limit
)
# write the title-ID and ID-description mappings to file
_write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
)
else:
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_output)
id_to_descr = get_id_to_description(entity_descr_output)
print()
print(now(), " * get entity frequencies:")
print()
logger.info("Get entity frequencies")
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
title_to_id,
id_to_descr,
entity_frequencies,
min_entity_freq
)
logger.info("Left with {} entities".format(len(description_list)))
logger.info("Train entity encoder")
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
logger.info("Get entity embeddings:")
embeddings = encoder.apply_encoder(description_list)
logger.info("Adding {} entities".format(len(entity_list)))
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
logger.info("Adding aliases")
_add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
logger.info("KB size: {} entities, {} aliases".format(
kb.get_size_entities(),
kb.get_size_aliases()))
logger.info("Done with kb")
return kb
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
min_entity_freq: int = 10):
filtered_title_to_id = dict()
entity_list = []
description_list = []
@ -72,58 +101,7 @@ def create_kb(
description_list.append(desc)
frequency_list.append(freq)
filtered_title_to_id[title] = entity
print(len(title_to_id.keys()), "original titles")
kept_nr = len(filtered_title_to_id.keys())
print("kept", kept_nr, "entities with min. frequency", min_entity_freq)
print()
print(now(), " * train entity encoder:")
print()
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
print()
print(now(), " * get entity embeddings:")
print()
embeddings = encoder.apply_encoder(description_list)
print(now(), " * adding", len(entity_list), "entities")
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
alias_cnt = _add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
print()
print(now(), " * adding", alias_cnt, "aliases")
print()
print()
print("# of entities in kb:", kb.get_size_entities())
print("# of aliases in kb:", kb.get_size_aliases())
print(now(), "Done with kb")
return kb
def _write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
return filtered_title_to_id, entity_list, description_list, frequency_list
def get_entity_to_id(entity_def_output):
@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output):
return entity_to_id
def get_id_to_description(entity_descr_output):
def get_id_to_description(entity_descr_path):
id_to_desc = dict()
with entity_descr_output.open("r", encoding="utf8") as csvfile:
with entity_descr_path.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output):
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
wp_titles = title_to_id.keys()
cnt = 0
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
entities=selected_entities,
probabilities=prior_probs,
)
cnt += 1
except ValueError as e:
print(e)
logger.error(e)
total_count = 0
counts = []
entities = []
@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias
line = prior_file.readline()
return cnt
def now():
return datetime.datetime.now()
def read_nlp_kb(model_dir, kb_file):
nlp = spacy.load(model_dir)
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_file)
logger.info("kb entities: {}".format(kb.get_size_entities()))
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
return nlp, kb

View File

@ -1,6 +1,7 @@
# coding: utf-8
from random import shuffle
import logging
import numpy as np
from spacy._ml import zero_init, create_default_optimizer
@ -10,6 +11,8 @@ from thinc.v2v import Model
from thinc.api import chain
from thinc.neural._classes.affine import Affine
logger = logging.getLogger(__name__)
class EntityEncoder:
"""
@ -50,21 +53,19 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
print("encoded:", stop, "entities")
logger.info("encoded: {} entities".format(stop))
return encodings
def train(self, description_list, to_print=False):
processed, loss = self._train_model(description_list)
if to_print:
print(
"Trained entity descriptions on",
processed,
"(non-unique) entities across",
self.epochs,
"epochs",
logger.info(
"Trained entity descriptions on {} ".format(processed) +
"(non-unique) entities across {} ".format(self.epochs) +
"epochs"
)
print("Final loss:", loss)
logger.info("Final loss: {}".format(loss))
def _train_model(self, description_list):
best_loss = 1.0
@ -93,7 +94,7 @@ class EntityEncoder:
loss = self._update(batch)
if batch_nr % 25 == 0:
print("loss:", loss)
logger.info("loss: {} ".format(loss))
processed += len(batch)
# in general, continue training if we haven't reached our ideal min yet

View File

@ -1,10 +1,13 @@
# coding: utf-8
from __future__ import unicode_literals
import logging
import random
import re
import bz2
import datetime
import json
from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator
@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character o
"""
ENTITY_FILE = "gold_entities.csv"
logger = logging.getLogger(__name__)
def now():
return datetime.datetime.now()
def create_training(wikipedia_input, entity_def_input, training_output, limit=None):
def create_training_examples_and_descriptions(wikipedia_input,
entity_def_input,
description_output,
training_output,
parse_descriptions,
limit=None):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit)
_process_wikipedia_texts(wikipedia_input,
wp_to_id,
description_output,
training_output,
parse_descriptions,
limit)
def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
def _process_wikipedia_texts(wikipedia_input,
wp_to_id,
output,
training_output,
parse_descriptions,
limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
entityfile_loc = training_output / ENTITY_FILE
with entityfile_loc.open("w", encoding="utf8") as entityfile:
# write entity training header file
_write_training_entity(
outputfile=entityfile,
article_id="article_id",
alias="alias",
entity="WD_id",
start="start",
end="end",
)
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
if parse_descriptions:
_write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(now(), "processed", cnt, "lines of Wikipedia dump")
logger.info("Processed {} articles".format(article_count))
for line in file:
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "</page>":
if article_id:
try:
_process_wp_text(
wp_to_id,
entityfile,
article_id,
article_title,
article_text.strip(),
training_output,
)
except Exception as e:
print(
"Error processing article", article_id, article_title, e
)
else:
print(
"Done processing a page, but couldn't find an article_id ?",
clean_text, entities = _process_wp_text(
article_title,
article_text,
wp_to_id
)
if clean_text is not None and entities is not None:
_write_training_entities(entity_file,
article_id,
clean_text,
entities)
if article_title in wp_to_id and parse_descriptions:
description = " ".join(clean_text[:1000].split(" ")[:-1])
_write_training_description(
descr_file,
wp_to_id[article_title],
description
)
article_count += 1
if article_count % 10000 == 0:
logger.info("Processed {} articles".format(article_count))
if limit and article_count >= limit:
break
article_text = ""
article_title = None
article_id = None
@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids:
article_id = ids[0]
if article_id in read_ids:
print(
logger.info(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
line = file.readline()
cnt += 1
print(now(), "processed", cnt, "lines of Wikipedia dump")
logger.info("Finished. Processed {} articles".format(article_count))
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
def _process_wp_text(
wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
found_entities = False
# ignore meta Wikipedia pages
if article_title.startswith("Wikipedia:"):
return
# remove the text tags
text = text_regex.search(article_text).group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return
# get the raw text without markup etc, keeping only interwiki links
clean_text = _get_clean_wp_text(text)
# read the text char by char to get the right offsets for the interwiki links
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
_write_training_entity(
outputfile=entityfile,
article_id=article_id,
alias=mention_buffer,
entity=qid,
start=start,
end=end,
)
found_entities = True
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
if found_entities:
_write_training_article(
article_id=article_id,
clean_text=final_text,
training_output=training_output,
)
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
@ -242,6 +148,29 @@ ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _process_wp_text(article_title, article_text, wp_to_id):
# ignore meta Wikipedia pages
if (
article_title.startswith("Wikipedia:") or
article_title.startswith("Kategori:")
):
return None, None
# remove the text tags
text_search = text_regex.search(article_text)
if text_search is None:
return None, None
text = text_search.group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return None, None
# get the raw text without markup etc, keeping only interwiki links
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
return clean_text, entities
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text):
return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output / "{}.txt".format(article_id)
with file_loc.open("w", encoding="utf8") as outputfile:
outputfile.write(clean_text)
def _remove_links(clean_text, wp_to_id):
# read the text char by char to get the right offsets for the interwiki links
entities = []
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
entities.append((mention_buffer, qid, start, end))
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
return final_text, entities
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
def _write_training_description(outputfile, qid, description):
if description is not None:
line = str(qid) + "|" + description + "\n"
outputfile.write(line)
def _write_training_entities(outputfile, article_id, clean_text, entities):
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
line = json.dumps(
{
"article_id": article_id,
"clean_text": clean_text,
"entities": entities_data
},
ensure_ascii=False) + "\n"
outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
For training,, it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found by using the candidate generator.
For testing, it will include all positive examples only."""
from tqdm import tqdm
data = []
num_entities = 0
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
entities = example["entities"]
if dev != is_dev(article_id) or len(clean_text) >= 30000:
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
def _get_gold_parse(doc, entities, dev, kb):
gold_entities = {}
tagged_ent_positions = set(
[(ent.start_char, ent.end_char) for ent in doc.ents]
)
for entity in entities:
entity_id = entity["entity"]
alias = entity["alias"]
start = entity["start"]
end = entity["end"]
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
should_add_ent = (
dev or
(
(start, end) in tagged_ent_positions and
entity_id in candidate_ids and
len(candidates) > 1
)
)
if should_add_ent:
value_by_id = {entity_id: 1.0}
if not dev:
random.shuffle(candidate_ids)
value_by_id.update({
kb_id: 0.0
for kb_id in candidate_ids
if kb_id != entity_id
})
gold_entities[(start, end)] = value_by_id
return GoldParse(doc, links=gold_entities)
def is_dev(article_id):
return article_id.endswith("3")
def read_training(nlp, training_dir, dev, limit, kb=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
When kb is provided (for training), it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found in the KB.
When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE
data = []
# assume the data is written sequentially, so we can reuse the article docs
current_article_id = None
current_doc = None
ents_by_offset = dict()
skip_articles = set()
total_entities = 0
with entityfile_loc.open("r", encoding="utf8") as file:
for line in file:
if not limit or len(data) < limit:
fields = line.replace("\n", "").split(sep="|")
article_id = fields[0]
alias = fields[1]
wd_id = fields[2]
start = fields[3]
end = fields[4]
if (
dev == is_dev(article_id)
and article_id != "article_id"
and article_id not in skip_articles
):
if not current_doc or (current_article_id != article_id):
# parse the new article text
file_name = article_id + ".txt"
try:
training_file = training_dir / file_name
with training_file.open("r", encoding="utf8") as f:
text = f.read()
# threshold for convenience / speed of processing
if len(text) < 30000:
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
for ent in current_doc.ents:
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
offset = "{}_{}".format(
ent.start_char, ent.end_char
)
ents_by_offset[offset] = ent
else:
skip_articles.add(article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)
skip_articles.add(article_id)
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
offset = "{}_{}".format(start, end)
found_ent = ents_by_offset.get(offset, None)
if found_ent:
if found_ent.text != alias:
skip_articles.add(article_id)
current_doc = None
else:
sent = found_ent.sent.as_doc()
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = {}
found_useful = False
for ent in sent.ents:
entry = (ent.start_char, ent.end_char)
gold_entry = (gold_start, gold_end)
if entry == gold_entry:
# add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb:
value_by_id = {}
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
found_useful = True
if kb_id != wd_id:
value_by_id[kb_id] = 0.0
else:
value_by_id[kb_id] = 1.0
gold_entities[entry] = value_by_id
# if no KB, keep all positive examples
else:
found_useful = True
value_by_id = {wd_id: 1.0}
gold_entities[entry] = value_by_id
# currently feeding the gold data one entity per sentence at a time
# setting all other entities to empty gold dictionary
else:
gold_entities[entry] = {}
if found_useful:
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1
if len(data) % 2500 == 0:
print(" -read", total_entities, "entities")
print(" -read", total_entities, "entities")
return data

View File

@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/
"""
from __future__ import unicode_literals
import datetime
import logging
from pathlib import Path
import plac
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
from bin.wiki_entity_linking import kb_creator
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
import spacy
from spacy import Errors
def now():
return datetime.datetime.now()
logger = logging.getLogger(__name__)
@plac.annotations(
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
output_dir=("Output directory", "positional", None, Path),
model=("Model name, should include pretrained vectors.", "positional", None, str),
model=("Model name or path, should include pretrained vectors.", "positional", None, str),
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
@ -41,7 +39,9 @@ def now():
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
)
def main(
wd_json,
@ -55,20 +55,29 @@ def main(
loc_prior_prob=None,
loc_entity_defs=None,
loc_entity_desc=None,
descriptions_from_wikipedia=False,
limit=None,
lang="en",
):
print(now(), "Creating KB with Wikipedia and WikiData")
print()
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
entity_freq_path = output_dir / ENTITY_FREQ_PATH
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
training_entities_path = output_dir / TRAINING_DATA_FILE
kb_path = output_dir / KB_FILE
logger.info("Creating KB with Wikipedia and WikiData")
if limit is not None:
print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.")
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
# STEP 0: set up IO
if not output_dir.exists():
output_dir.mkdir()
output_dir.mkdir(parents=True)
# STEP 1: create the NLP object
print(now(), "STEP 1: loaded model", model)
logger.info("STEP 1: Loading model {}".format(model))
nlp = spacy.load(model)
# check the length of the nlp vectors
@ -79,64 +88,68 @@ def main(
)
# STEP 2: create prior probabilities from WP
print()
if loc_prior_prob:
print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob)
else:
if not prior_prob_path.exists():
# It takes about 2h to process 1000M lines of Wikipedia XML dump
loc_prior_prob = output_dir / "prior_prob.csv"
print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob)
wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit)
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
print()
print(now(), "STEP 3: calculating entity frequencies")
loc_entity_freq = output_dir / "entity_freq.csv"
wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False)
logger.info("STEP 3: calculating entity frequencies")
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
loc_kb = output_dir / "kb"
# STEP 4: reading entity descriptions and definitions from WikiData or from file
print()
if loc_entity_defs and loc_entity_desc:
read_raw = False
print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs)
print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc)
else:
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
message = " and descriptions" if not descriptions_from_wikipedia else ""
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
# It takes about 10h to process 55M lines of Wikidata JSON dump
read_raw = True
loc_entity_defs = output_dir / "entity_defs.csv"
loc_entity_desc = output_dir / "entity_descriptions.csv"
print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions")
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
wd_json,
limit,
to_print=False,
lang=lang,
parse_descriptions=(not descriptions_from_wikipedia),
)
wd.write_entity_files(entity_defs_path, title_to_id)
if not descriptions_from_wikipedia:
wd.write_entity_description_files(entity_descr_path, id_to_descr)
logger.info("STEP 4: read entity definitions" + message)
# STEP 5: creating the actual KB
# STEP 5: Getting gold entities from wikipedia
message = " and descriptions" if descriptions_from_wikipedia else ""
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
training_set_creator.create_training_examples_and_descriptions(
wp_xml,
entity_defs_path,
entity_descr_path,
training_entities_path,
parse_descriptions=descriptions_from_wikipedia,
limit=limit,
)
logger.info("STEP 5: read gold entities" + message)
# STEP 6: creating the actual KB
# It takes ca. 30 minutes to pretrain the entity embeddings
print()
print(now(), "STEP 5: creating the KB at", loc_kb)
logger.info("STEP 6: creating the KB at {}".format(kb_path))
kb = kb_creator.create_kb(
nlp=nlp,
max_entities_per_alias=max_per_alias,
min_entity_freq=min_freq,
min_occ=min_pair,
entity_def_output=loc_entity_defs,
entity_descr_output=loc_entity_desc,
count_input=loc_entity_freq,
prior_prob_input=loc_prior_prob,
wikidata_input=wd_json,
entity_def_input=entity_defs_path,
entity_descr_path=entity_descr_path,
count_input=entity_freq_path,
prior_prob_input=prior_prob_path,
entity_vector_length=entity_vector_length,
limit=limit,
read_raw_data=read_raw,
)
if read_raw:
print(" - wrote entity definitions to", loc_entity_defs)
print(" - wrote writing entity descriptions to", loc_entity_desc)
kb.dump(loc_kb)
nlp.to_disk(output_dir / "nlp")
kb.dump(kb_path)
nlp.to_disk(output_dir / KB_MODEL_DIR)
print()
print(now(), "Done!")
logger.info("Done!")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)

View File

@ -1,17 +1,19 @@
# coding: utf-8
from __future__ import unicode_literals
import bz2
import gzip
import json
import logging
import datetime
logger = logging.getLogger(__name__)
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
lang = "en"
site_filter = "enwiki"
site_filter = '{}wiki'.format(lang)
# properties filter (currently disabled to get ALL data)
prop_filter = dict()
@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
parse_properties = False
parse_sitelinks = True
parse_labels = False
parse_descriptions = True
parse_aliases = False
parse_claims = False
with bz2.open(wikidata_file, mode="rb") as file:
line = file.readline()
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(
datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump"
)
with gzip.open(wikidata_file, mode='rb') as file:
for cnt, line in enumerate(file):
if limit and cnt >= limit:
break
if cnt % 500000 == 0:
logger.info("processed {} lines of WikiData dump".format(cnt))
clean_line = line.strip()
if clean_line.endswith(b","):
clean_line = clean_line[:-1]
@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
if to_print:
print()
line = file.readline()
cnt += 1
print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump")
return title_to_id, id_to_descr
def write_entity_files(entity_def_output, title_to_id):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
def write_entity_description_files(entity_descr_output, id_to_descr):
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")

View File

@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/
from __future__ import unicode_literals
import random
import datetime
import logging
from pathlib import Path
import plac
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
import spacy
from spacy.kb import KnowledgeBase
from spacy.util import minibatch, compounding
def now():
return datetime.datetime.now()
logger = logging.getLogger(__name__)
@plac.annotations(
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
output_dir=("Output directory", "option", "o", Path),
loc_training=("Location to training data", "option", "k", Path),
wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
epochs=("Number of training iterations (default 10)", "option", "e", int),
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
lr=("Learning rate (default 0.005)", "option", "n", float),
l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
)
def main(
dir_kb,
output_dir=None,
loc_training=None,
wp_xml=None,
epochs=10,
dropout=0.5,
lr=0.005,
l2=1e-6,
train_inst=None,
dev_inst=None,
limit=None,
):
print(now(), "Creating Entity Linker with Wikipedia and WikiData")
print()
logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
nlp_dir = dir_kb / KB_MODEL_DIR
kb_path = output_dir / KB_FILE
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
# STEP 0: set up IO
if output_dir and not output_dir.exists():
if not output_dir.exists():
output_dir.mkdir()
# STEP 1 : load the NLP object
nlp_dir = dir_kb / "nlp"
print(now(), "STEP 1: loading model from", nlp_dir)
nlp = spacy.load(nlp_dir)
logger.info("STEP 1: loading model from {}".format(nlp_dir))
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
# check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pre-trained `ner` component.")
# STEP 2 : read the KB
print()
print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(dir_kb / "kb")
# STEP 2: create a training dataset from WP
logger.info("STEP 2: reading training dataset from {}".format(training_path))
# STEP 3: create a training dataset from WP
print()
if loc_training:
print(now(), "STEP 3: reading training dataset from", loc_training)
else:
if not wp_xml:
raise ValueError(
"Either provide a path to a preprocessed training directory, "
"or to the original Wikipedia XML dump."
)
if output_dir:
loc_training = output_dir / "training_data"
else:
loc_training = dir_kb / "training_data"
if not loc_training.exists():
loc_training.mkdir()
print(now(), "STEP 3: creating training dataset at", loc_training)
if limit is not None:
print("Warning: reading only", limit, "lines of Wikipedia dump.")
loc_entity_defs = dir_kb / "entity_defs.csv"
training_set_creator.create_training(
wikipedia_input=wp_xml,
entity_def_input=loc_entity_defs,
training_output=loc_training,
limit=limit,
)
# STEP 4: parse the training data
print()
print(now(), "STEP 4: parse the training & evaluation data")
# for training, get pos & neg instances that correspond to entries in the kb
print("Parsing training data, limit =", train_inst)
train_data = training_set_creator.read_training(
nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
nlp=nlp,
entity_file_path=training_path,
dev=False,
limit=train_inst,
kb=kb,
)
print("Training on", len(train_data), "articles")
print()
print("Parsing dev testing data, limit =", dev_inst)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
nlp=nlp,
entity_file_path=training_path,
dev=True,
limit=dev_inst,
kb=kb,
)
print("Dev testing on", len(dev_data), "articles")
print()
# STEP 5: create and train the entity linking pipe
print()
print(now(), "STEP 5: training Entity Linking pipe")
# STEP 3: create and train the entity linking pipe
logger.info("STEP 3: training Entity Linking pipe")
el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
@ -142,275 +102,70 @@ def main(
optimizer.learn_rate = lr
optimizer.L2 = l2
if not train_data:
print("Did not find any training data")
else:
for itn in range(epochs):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
batchnr = 0
logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data)))
with nlp.disable_pipes(*other_pipes):
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
except Exception as e:
print("Error updating batch:", e)
if batchnr > 0:
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
losses["entity_linker"] = losses["entity_linker"] / batchnr
print(
"Epoch, train loss",
itn,
round(losses["entity_linker"], 2),
" / dev accuracy avg",
round(dev_acc_context, 3),
)
# STEP 6: measure the performance of our trained pipe on an independent dev set
print()
if len(dev_data):
print()
print(now(), "STEP 6: performance measurement of Entity Linking pipe")
print()
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
dev_data, kb
)
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
print("dev accuracy random:", round(acc_r, 3), random_by_label)
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
print("dev accuracy prior:", round(acc_p, 3), prior_by_label)
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)
# STEP 7: apply the EL pipe on a toy example
print()
print(now(), "STEP 7: applying Entity Linking to toy example")
print()
run_el_toy_example(nlp=nlp)
# STEP 8: write the NLP pipeline (including entity linker) to file
if output_dir:
print()
nlp_loc = output_dir / "nlp"
print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
nlp.to_disk(nlp_loc)
print()
print()
print(now(), "Done!")
def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict()
incorrect_by_label = dict()
docs = [d for d, g in data if len(d) > 0]
if el_pipe is not None:
docs = list(el_pipe.pipe(docs))
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
if gold_entity == pred_entity:
correct = correct_by_label.get(ent_label, 0)
correct_by_label[ent_label] = correct + 1
else:
incorrect = incorrect_by_label.get(ent_label, 0)
incorrect_by_label[ent_label] = incorrect + 1
if error_analysis:
print(ent.text, "in", doc)
print(
"Predicted",
pred_entity,
"should have been",
gold_entity,
)
print()
except Exception as e:
print("Error assessing accuracy", e)
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
return acc, acc_by_label
def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_d = dict()
random_correct_d = dict()
random_incorrect_d = dict()
oracle_correct_d = dict()
oracle_incorrect_d = dict()
prior_correct_d = dict()
prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
random_candidate = ""
if candidates:
scores = []
for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
best_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate:
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else:
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate:
random_correct_d[label] = random_correct_d.get(label, 0) + 1
else:
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate:
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else:
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e:
print("Error assessing accuracy", e)
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
return (
counts_d,
acc_rand,
acc_rand_d,
acc_prior,
acc_prior_d,
acc_oracle,
acc_oracle_d,
dev_baseline_accuracies = measure_baselines(
dev_data, kb
)
logger.info("Dev Baseline Accuracies:")
logger.info(dev_baseline_accuracies.report_accuracy("random"))
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
def _offset(start, end):
return "{}_{}".format(start, end)
for itn in range(epochs):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
batchnr = 0
with nlp.disable_pipes(*other_pipes):
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
except Exception as e:
logger.error("Error updating batch:" + str(e))
if batchnr > 0:
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
measure_performance(dev_data, kb, el_pipe)
def calculate_acc(correct_by_label, incorrect_by_label):
acc_by_label = dict()
total_correct = 0
total_incorrect = 0
all_keys = set()
all_keys.update(correct_by_label.keys())
all_keys.update(incorrect_by_label.keys())
for label in sorted(all_keys):
correct = correct_by_label.get(label, 0)
incorrect = incorrect_by_label.get(label, 0)
total_correct += correct
total_incorrect += incorrect
if correct == incorrect == 0:
acc_by_label[label] = 0
else:
acc_by_label[label] = correct / (correct + incorrect)
acc = 0
if not (total_correct == total_incorrect == 0):
acc = total_correct / (total_correct + total_incorrect)
return acc, acc_by_label
# STEP 4: measure the performance of our trained pipe on an independent dev set
logger.info("STEP 4: performance measurement of Entity Linking pipe")
measure_performance(dev_data, kb, el_pipe)
# STEP 5: apply the EL pipe on a toy example
logger.info("STEP 5: applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp)
if output_dir:
# STEP 6: write the NLP pipeline (including entity linker) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir)
logger.info("Done!")
def check_kb(kb):
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
candidates = kb.get_candidates(mention)
print("generating candidates for " + mention + " :")
logger.info("generating candidates for " + mention + " :")
for c in candidates:
print(
" ",
c.prior_prob,
logger.info(" ".join[
str(c.prior_prob),
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
)
print()
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
])
def run_el_toy_example(nlp):
@ -421,11 +176,11 @@ def run_el_toy_example(nlp):
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
print(text)
logger.info(text)
for ent in doc.ents:
print(" ent", ent.text, ent.label_, ent.kb_id_)
print()
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)

View File

@ -5,6 +5,9 @@ import re
import bz2
import csv
import datetime
import logging
from bin.wiki_entity_linking import LOG_FORMAT
"""
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict()
logger = logging.getLogger(__name__)
# these will/should be matched ignoring case
wiki_namespaces = [
"b",
@ -116,10 +122,6 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
def now():
return datetime.datetime.now()
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 25000000 == 0:
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
line = file.readline()
cnt += 1
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
# write all aliases and their entities and count occurrences to file
with prior_prob_output.open("w", encoding="utf8") as outputfile: