calculate raw entity counts offline to speed up KB construction

svlandeg 2019-04-30 11:39:42 +02:00
parent 19e8f339cb
commit 653b7d9c87
1 changed file with 64 additions and 27 deletions


@@ -1,23 +1,25 @@
# coding: utf-8
from __future__ import unicode_literals
from spacy.vocab import Vocab
"""
Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm.
"""
import re
import csv
import json
import spacy
import datetime
import bz2
from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab
# TODO: remove hardcoded paths
WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.json.bz2'
ENWIKI_DUMP = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream.xml.bz2'
ENWIKI_INDEX = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream-index.txt.bz2'
PRIOR_PROB = 'C:/Users/Sofie/Documents/data/wikipedia/prior_prob.csv'
ENTITY_COUNTS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_freq.csv'
KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb'
VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
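PRIOR_PROB and ENTITY_COUNTS are plain pipe-delimited intermediate files. Judging from the read and write code further down in this diff (not from the data itself), their layouts are:

    PRIOR_PROB:     header alias|count|entity, one row per (alias, entity) pair, sorted by alias and then by count
    ENTITY_COUNTS:  header entity|count, one aggregated row per entity title, written once by _write_entity_counts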
@@ -44,18 +46,30 @@ map_alias_to_link = dict()
def create_kb(vocab, max_entities_per_alias, min_occ, to_print=False):
kb = KnowledgeBase(vocab=vocab)
id_to_title = _read_wikidata_entities(limit=None)
title_to_id = {v: k for k, v in id_to_title.items()}
print()
print("1. _read_wikidata_entities", datetime.datetime.now())
print()
title_to_id = _read_wikidata_entities(limit=100000)
entity_list = list(id_to_title.keys())
title_list = [id_to_title[x] for x in entity_list]
entity_frequencies = _get_entity_frequencies(entities=title_list, to_print=False)
title_list = list(title_to_id.keys())
entity_list = [title_to_id[x] for x in title_list]
print()
print("2. _get_entity_frequencies", datetime.datetime.now())
print()
entity_frequencies = _get_entity_frequencies(entities=title_list)
print()
print("3. _add_entities", datetime.datetime.now())
print()
_add_entities(kb,
entities=entity_list,
probs=entity_frequencies,
to_print=to_print)
print()
print("4. _add_aliases", datetime.datetime.now())
print()
_add_aliases(kb,
title_to_id=title_to_id,
max_entities_per_alias=max_entities_per_alias,
@@ -72,15 +86,26 @@ def create_kb(vocab, max_entities_per_alias, min_occ, to_print=False):
return kb
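Each of the four numbered step banners above repeats the same three print calls. A small helper along these lines (not part of this commit; the name is illustrative, and it relies on the datetime import already at the top of the script) would express the same thing once:

    def _print_step(label):
        # illustrative helper: blank line, step label plus timestamp, blank line
        print()
        print(label, datetime.datetime.now())
        print()

    # e.g. _print_step("1. _read_wikidata_entities")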
def _get_entity_frequencies(entities, to_print=False):
count_entities = [0 for _ in entities]
def _get_entity_frequencies(entities):
entity_to_count = dict()
with open(ENTITY_COUNTS, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
# skip header
next(csvreader)
for row in csvreader:
entity_to_count[row[0]] = int(row[1])
return [entity_to_count.get(e, 0) for e in entities]
def _write_entity_counts(to_print=False):
entity_to_count = dict()
total_count = 0
with open(PRIOR_PROB, mode='r', encoding='utf8') as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
# we can read this file sequentially: it's sorted by alias, and then by count
while line:
splits = line.replace('\n', "").split(sep='|')
@@ -88,23 +113,26 @@ def _get_entity_frequencies(entities, to_print=False):
count = int(splits[1])
entity = splits[2]
if entity in entities:
index = entities.index(entity)
count_entities[index] = count_entities[index] + count
current_count = entity_to_count.get(entity, 0)
entity_to_count[entity] = current_count + count
total_count += count
line = prior_file.readline()
with open(ENTITY_COUNTS, mode='w', encoding='utf8') as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
if to_print:
for entity, count in zip(entities, count_entities):
for entity, count in entity_to_count.items():
print("Entity count:", entity, count)
print("Total count:", total_count)
return [x*100 / total_count for x in count_entities]
def _add_entities(kb, entities, probs, to_print=False):
# TODO: this should be a bulk method
for entity, prob in zip(entities, probs):
kb.add_entity(entity=entity, prob=prob)
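The speed-up promised in the commit title comes from the two changes above: the per-line list scan over the title list is replaced by a dictionary lookup, the aggregation over the prior-probability file now happens once, offline, in _write_entity_counts, and the frequencies are kept as raw counts rather than converted to percentages. Schematically (variable names as in the code above):

    # before: one linear scan of the title list per line of PRIOR_PROB
    index = entities.index(entity)
    count_entities[index] = count_entities[index] + count

    # after: constant-time dictionary update, written out once to ENTITY_COUNTS
    entity_to_count[entity] = entity_to_count.get(entity, 0) + count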
@@ -166,13 +194,13 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, to_print=Fals
def _read_wikidata_entities(limit=None, to_print=False):
""" Read the JSON wiki data and parse out the entities"""
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
languages = {'en', 'de'}
prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
site_filter = 'enwiki'
entity_dict = dict()
title_to_id = dict()
# parse appropriate fields - depending on what we need in the KB
parse_properties = False
@@ -192,12 +220,12 @@ def _read_wikidata_entities(limit=None, to_print=False):
clean_line = clean_line[:-1]
if len(clean_line) > 1:
obj = json.loads(clean_line)
unique_id = obj["id"]
entry_type = obj["type"]
if unique_id[0] == 'Q' and entry_type == "item":
if entry_type == "item":
# filtering records on their properties
keep = False
claims = obj["claims"]
for prop, value_set in prop_filter.items():
claim_property = claims.get(prop, None)
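The lines elided between this hunk and the next walk the individual claim statements. As a rough sketch of the general pattern for a WikiData dump (not a copy of the elided code), each statement's mainsnak carries the target item id, which is matched against the filter values:

    for stmt in claim_property:
        # 'mainsnak' -> 'datavalue' -> 'value' -> 'id' holds e.g. 'Q5' (human)
        value_id = stmt.get("mainsnak", {}).get("datavalue", {}).get("value", {}).get("id")
        if value_id in value_set:
            keep = True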
@@ -209,6 +237,8 @@ def _read_wikidata_entities(limit=None, to_print=False):
keep = True
if keep:
unique_id = obj["id"]
if to_print:
print("ID:", unique_id)
print("type:", entry_type)
@@ -225,9 +255,10 @@ def _read_wikidata_entities(limit=None, to_print=False):
if parse_sitelinks:
site_value = obj["sitelinks"].get(site_filter, None)
if site_value:
site = site_value['title']
if to_print:
print(site_filter, ":", site_value['title'])
entity_dict[unique_id] = site_value['title']
print(site_filter, ":", site)
title_to_id[site] = unique_id
if parse_labels:
labels = obj["labels"]
@@ -262,7 +293,7 @@ def _read_wikidata_entities(limit=None, to_print=False):
line = file.readline()
cnt += 1
return entity_dict
return title_to_id
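For context, the dump that _read_wikidata_entities consumes is one huge bz2-compressed JSON array with one entity per line, which is why the code above strips a trailing comma and parses each line on its own. A self-contained sketch of that streaming pattern (function and argument names are illustrative, not from this commit):

    import bz2
    import json

    def iter_wikidata_items(dump_path, limit=None):
        # stream the bz2-compressed dump line by line instead of loading it whole
        with bz2.open(dump_path, mode='rb') as dump_file:
            line = dump_file.readline()
            cnt = 0
            while line and (limit is None or cnt < limit):
                clean_line = line.strip()
                if clean_line.endswith(b","):
                    clean_line = clean_line[:-1]
                if len(clean_line) > 1:  # skip the enclosing '[' and ']' lines
                    yield json.loads(clean_line.decode("utf-8"))
                line = dump_file.readline()
                cnt += 1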
def _read_wikipedia_prior_probs():
@@ -469,6 +500,7 @@ def capitalize_first(text):
if __name__ == "__main__":
to_create_prior_probs = False
to_create_entity_counts = False
to_create_kb = True
to_read_kb = False
@@ -477,20 +509,25 @@ if __name__ == "__main__":
if to_create_prior_probs:
_read_wikipedia_prior_probs()
# STEP 2 : deduce entity frequencies from WP
# run only once!
if to_create_entity_counts:
_write_entity_counts()
if to_create_kb:
# STEP 2 : create KB
# STEP 3 : create KB
my_nlp = spacy.load('en_core_web_sm')
my_vocab = my_nlp.vocab
my_kb = create_kb(my_vocab, max_entities_per_alias=10, min_occ=5, to_print=False)
print("kb entities:", my_kb.get_size_entities())
print("kb aliases:", my_kb.get_size_aliases())
# STEP 3 : write KB to file
# STEP 4 : write KB to file
my_kb.dump(KB_FILE)
my_vocab.to_disk(VOCAB_DIR)
if to_read_kb:
# STEP 4 : read KB back in from file
# STEP 5 : read KB back in from file
my_vocab = Vocab()
my_vocab.from_disk(VOCAB_DIR)
my_kb = KnowledgeBase(vocab=my_vocab)
@@ -507,5 +544,5 @@ if __name__ == "__main__":
print("alias:", c.alias_)
print("prior prob:", c.prior_prob)
# STEP 5: add KB to NLP pipeline
# STEP 6: add KB to NLP pipeline
# add_el(my_kb, nlp)