# coding: utf-8
from __future__ import unicode_literals
import re
import csv
import bz2
import datetime
from . import wikipedia_processor as wp
"""
Process Wikipedia interlinks to generate a training dataset for the entity linking (EL) algorithm.
"""
# Name of the '|'-delimited file (header: article_id|alias|entity|correct)
# that is written to, and read back from, the training_output directory.
ENTITY_FILE = "gold_entities.csv"
def create_training(kb, entity_input, training_output, *, limit=100000000):
    """Create entity-linking training data from the Wikipedia dump.

    kb: knowledge base used to look up candidate entities; must be truthy.
    entity_input: path of the '|'-delimited file read by _get_entity_to_id.
    training_output: directory that receives the article texts and the
        gold-entity file.
    limit: maximum number of dump lines to process. The default keeps the
        previously hard-coded value; None or 0 processes the whole dump.

    Raises ValueError if kb is not provided.
    """
    if not kb:
        raise ValueError("kb should be defined")
    wp_to_id = _get_entity_to_id(entity_input)
    # TODO: default to the full dataset once a complete run has been validated
    _process_wikipedia_texts(kb, wp_to_id, training_output, limit=limit)
def _get_entity_to_id(entity_input):
entity_to_id = dict()
with open(entity_input, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
# skip header
next(csvreader)
for row in csvreader:
entity_to_id[row[0]] = row[1]
return entity_to_id
def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive and negative instances.
The bz2-compressed dump (wp.ENWIKI_DUMP) is streamed line by line as raw
bytes, without an XML parser. A header row is first written to
training_output + "/" + ENTITY_FILE; on the end-of-page branch the
collected article text is passed to _process_wp_text together with the
open entity file.
kb, wp_to_id: passed through to _process_wp_text (create_training builds
wp_to_id with _get_entity_to_id).
limit: the loop stops once the line counter cnt reaches limit; a falsy
limit (None or 0) means no limit.
"""
# NOTE(review): the XML tag text in title_regex/id_regex appears to have been
# lost (title_regex is not even a valid literal as written). Presumably they
# matched the text between <title>...</title> and <id>...</id> -- TODO restore.
title_regex = re.compile(r'(?<=
).*(?=)')
id_regex = re.compile(r'(?<=)\d*(?=)')
# its use is not visible in this view -- presumably ids of articles already seen
read_ids = set()
entityfile_loc = training_output + "/" + ENTITY_FILE
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
# write entity training header file
_write_training_entity(outputfile=entityfile,
article_id="article_id",
alias="alias",
entity="entity",
correct="correct")
# stream the compressed dump as raw bytes, one line at a time
with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
line = file.readline()
cnt = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
# Stop at end of file or once cnt reaches limit (a falsy limit means no limit).
# NOTE(review): the statements that increment cnt and read the next line are
# not visible in this view -- confirm they exist, or this loop never ends.
while line and (not limit or cnt < limit):
# progress report whenever cnt is a multiple of 1,000,000
if cnt % 1000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8")
# print(clean_line)
# NOTE(review): every tag comparison below tests against "" -- the tag
# literals (presumably "<revision>", "</revision>", "<page>", "</page>")
# appear to have been stripped. As written each elif repeats its if's
# condition and can never run, so the end-of-page branch is unreachable.
# TODO restore.
if clean_line == "":
reading_revision = True
elif clean_line == "":
reading_revision = False
# Start reading new page
if clean_line == "":
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "":
if article_id:
try:
_process_wp_text(kb, wp_to_id, entityfile, article_id, article_text.strip(), training_output)
# on a previous run, an error occurred after 46M lines and 2h
# so failures are logged and the article skipped instead of aborting the run
except Exception as e:
print("Error processing article", article_id, article_title, e)
else:
print("Done processing a page, but couldn't find an article_id ?")
print(article_title)
print(article_text)
# reset the per-page state
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
# start reading text within a page
# NOTE(review): the next line is garbled. The remainder of this loop and the
# beginning of _process_wp_text (its def line and the code that sets
# clean_text, alias and entity) appear to be missing; the lines after it are
# the tail of _process_wp_text. TODO restore from version control.
if ").*(?=", entity)
# every KB candidate entity for this alias
candidates = kb.get_candidates(alias)
# as training data, we only store entities that are sufficiently ambiguous
if len(candidates) > 1:
# NOTE(review): if this runs once per link, the article file is rewritten for
# every ambiguous alias -- writing it once per article would be cheaper.
_write_training_article(article_id=article_id, clean_text=clean_text, training_output=training_output)
# print("alias", alias)
# write every incorrect candidate as a negative instance (correct="0")
for c in candidates:
if entity != c.entity_:
_write_training_entity(outputfile=entityfile,
article_id=article_id,
alias=alias,
entity=c.entity_,
correct="0")
# write the gold entity as the positive instance (correct="1")
_write_training_entity(outputfile=entityfile,
article_id=article_id,
alias=alias,
entity=entity,
correct="1")
# print("gold entity", entity)
# print()
# _run_ner_depr(nlp, clean_text, article_dict)
# print()
info_regex = re.compile(r'{[^{]*?}')
interwiki_regex = re.compile(r'\[\[([^|]*?)]]')
interwiki_2_regex = re.compile(r'\[\[[^|]*?\|([^|]*?)]]')
htlm_regex = re.compile(r'<!--[^!]*-->')
category_regex = re.compile(r'\[\[Category:[^\[]*]]')
file_regex = re.compile(r'\[\[File:[^[\]]+]]')
ref_regex = re.compile(r'<ref.*?>') # non-greedy
ref_2_regex = re.compile(r'</ref.*?>') # non-greedy
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
clean_text = clean_text.replace('\'\'\'', '')
clean_text = clean_text.replace('\'\'', '')
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
try_again = False
previous_length = len(clean_text)
# remove simple interwiki links (no alternative name)
clean_text = interwiki_regex.sub(r'\1', clean_text)
# remove simple interwiki links by picking the alternative name
clean_text = interwiki_2_regex.sub(r'\1', clean_text)
# remove HTML comments
clean_text = htlm_regex.sub('', clean_text)
# remove Category and File statements
clean_text = category_regex.sub('', clean_text)
clean_text = file_regex.sub('', clean_text)
# remove multiple =
while '==' in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
clean_text = clean_text.replace(" = ", ". ")
clean_text = clean_text.replace("= ", ".")
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
clean_text = ref_regex.sub('', clean_text)
clean_text = ref_2_regex.sub('', clean_text)
# remove additional wikiformatting
clean_text = re.sub(r'<blockquote>', '', clean_text)
clean_text = re.sub(r'</blockquote>', '', clean_text)
# change special characters back to normal ones
clean_text = clean_text.replace(r'<', '<')
clean_text = clean_text.replace(r'>', '>')
clean_text = clean_text.replace(r'"', '"')
clean_text = clean_text.replace(r' ', ' ')
clean_text = clean_text.replace(r'&', '&')
# remove multiple spaces
while ' ' in clean_text:
clean_text = clean_text.replace(' ', ' ')
return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output + "/" + str(article_id) + ".txt"
with open(file_loc, mode='w', encoding='utf8') as outputfile:
outputfile.write(clean_text)
def _write_training_entity(outputfile, article_id, alias, entity, correct):
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + correct + "\n")
def read_training_entities(training_output, collect_correct=True, collect_incorrect=False):
    """Read back the gold-entity file from the training_output directory.

    training_output: directory containing ENTITY_FILE.
    collect_correct: collect rows whose ``correct`` field is "1".
    collect_incorrect: collect rows whose ``correct`` field is "0".

    Returns ``(correct, incorrect)``: ``correct`` maps
    article_id -> {alias: entity}; ``incorrect`` maps
    article_id -> {alias: set of entities}. The header row (whose correct
    field is neither "1" nor "0") and lines with fewer than four fields
    (e.g. blank lines) are ignored.

    Raises ValueError if, while collecting correct entries, one article has
    more than one correct row for the same alias.
    """
    entityfile_loc = training_output + "/" + ENTITY_FILE
    incorrect_entries_per_article = {}
    correct_entries_per_article = {}
    with open(entityfile_loc, mode='r', encoding='utf8') as entity_file:
        for line in entity_file:
            fields = line.rstrip('\n').split('|')
            if len(fields) < 4:
                # blank or truncated line: nothing to collect
                continue
            article_id, alias, entity, correct = fields[:4]
            if correct == "1" and collect_correct:
                entry_dict = correct_entries_per_article.setdefault(article_id, {})
                if alias in entry_dict:
                    # a single formatted message (not several ValueError args,
                    # which would render as a tuple)
                    raise ValueError(
                        "Found alias {} multiple times for article {} in {}".format(
                            alias, article_id, ENTITY_FILE))
                entry_dict[alias] = entity
            if correct == "0" and collect_incorrect:
                entry_dict = incorrect_entries_per_article.setdefault(article_id, {})
                entry_dict.setdefault(alias, set()).add(entity)
    return correct_entries_per_article, incorrect_entries_per_article