mirror of https://github.com/explosion/spaCy.git
431 lines
15 KiB
Python
431 lines
15 KiB
Python
# coding: utf-8
|
|
"""Script to take a previously created Knowledge Base and train an entity linking
|
|
pipeline. The provided KB directory should hold the kb, the original nlp object and
|
|
its vocab used to create the KB, and a few auxiliary files such as the entity definitions,
|
|
as created by the script `wikidata_create_kb`.
|
|
|
|
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
|
from https://dumps.wikimedia.org/enwiki/latest/
|
|
|
|
"""
|
|
from __future__ import unicode_literals
|
|
|
|
import random
|
|
import datetime
|
|
from pathlib import Path
|
|
import plac
|
|
|
|
from bin.wiki_entity_linking import training_set_creator
|
|
|
|
import spacy
|
|
from spacy.kb import KnowledgeBase
|
|
|
|
from spacy import Errors
|
|
from spacy.util import minibatch, compounding
|
|
|
|
|
|
def now():
    """Return the current date and time, used to timestamp progress messages."""
    timestamp = datetime.datetime.now()
    return timestamp
|
|
|
|
|
|
@plac.annotations(
    dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
    output_dir=("Output directory", "option", "o", Path),
    loc_training=("Location to training data", "option", "k", Path),
    wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
    epochs=("Number of training iterations (default 10)", "option", "e", int),
    dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
    lr=("Learning rate (default 0.005)", "option", "n", float),
    l2=("L2 regularization", "option", "r", float),
    train_inst=("# training instances (default 90% of all)", "option", "t", int),
    dev_inst=("# test instances (default 10% of all)", "option", "d", int),
    limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
)
def main(
    dir_kb,
    output_dir=None,
    loc_training=None,
    wp_xml=None,
    epochs=10,
    dropout=0.5,
    lr=0.005,
    l2=1e-6,
    train_inst=None,
    dev_inst=None,
    limit=None,
):
    # Command-line entry point: load the NLP object and KB from `dir_kb`, obtain
    # training data (read from `loc_training`, or created from the Wikipedia dump
    # `wp_xml`), train an "entity_linker" pipe, report dev accuracy and baselines,
    # run a toy example and, when `output_dir` is given, write the pipeline there.
    #
    # Raises ValueError when the loaded pipeline has no "ner" component, or when
    # neither `loc_training` nor `wp_xml` is provided.
    #
    # NOTE: documented with comments rather than a docstring so that the help
    # text plac generates for this function stays unchanged.
    print(now(), "Creating Entity Linker with Wikipedia and WikiData")
    print()

    # STEP 0: set up IO
    if output_dir and not output_dir.exists():
        # parents=True: a nested output path must not fail just because an
        # intermediate directory has not been created yet
        output_dir.mkdir(parents=True)

    # STEP 1 : load the NLP object
    nlp_dir = dir_kb / "nlp"
    print(now(), "STEP 1: loading model from", nlp_dir)
    nlp = spacy.load(nlp_dir)

    # check that there is a NER component in the pipeline
    if "ner" not in nlp.pipe_names:
        raise ValueError(Errors.E152)

    # STEP 2 : read the KB
    print()
    print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
    kb = KnowledgeBase(vocab=nlp.vocab)
    kb.load_bulk(dir_kb / "kb")

    # STEP 3: create a training dataset from WP
    print()
    if loc_training:
        print(now(), "STEP 3: reading training dataset from", loc_training)
    else:
        # no existing training data was given, so it has to be created from the dump
        if not wp_xml:
            raise ValueError(Errors.E153)

        if output_dir:
            loc_training = output_dir / "training_data"
        else:
            loc_training = dir_kb / "training_data"
        if not loc_training.exists():
            loc_training.mkdir()
        print(now(), "STEP 3: creating training dataset at", loc_training)

        if limit is not None:
            print("Warning: reading only", limit, "lines of Wikipedia dump.")

        loc_entity_defs = dir_kb / "entity_defs.csv"
        training_set_creator.create_training(
            wikipedia_input=wp_xml,
            entity_def_input=loc_entity_defs,
            training_output=loc_training,
            limit=limit,
        )

    # STEP 4: parse the training data
    print()
    print(now(), "STEP 4: parse the training & evaluation data")

    # for training, get pos & neg instances that correspond to entries in the kb
    print("Parsing training data, limit =", train_inst)
    train_data = training_set_creator.read_training(
        nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
    )

    print("Training on", len(train_data), "articles")
    print()

    print("Parsing dev testing data, limit =", dev_inst)
    # for testing, get all pos instances, whether or not they are in the kb
    dev_data = training_set_creator.read_training(
        nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
    )

    print("Dev testing on", len(dev_data), "articles")
    print()

    # STEP 5: create and train the entity linking pipe
    print()
    print(now(), "STEP 5: training Entity Linking pipe")

    el_pipe = nlp.create_pipe(
        name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
    )
    el_pipe.set_kb(kb)
    nlp.add_pipe(el_pipe, last=True)

    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
        optimizer = nlp.begin_training()
        optimizer.learn_rate = lr
        optimizer.L2 = l2

    if not train_data:
        print("Did not find any training data")
    else:
        for itn in range(epochs):
            random.shuffle(train_data)
            losses = {}
            batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
            batchnr = 0  # number of batches that were updated without an error

            with nlp.disable_pipes(*other_pipes):
                for batch in batches:
                    # best effort: a failing batch is reported and skipped so
                    # that it does not abort the whole training run
                    try:
                        docs, golds = zip(*batch)
                        nlp.update(
                            docs=docs,
                            golds=golds,
                            sgd=optimizer,
                            drop=dropout,
                            losses=losses,
                        )
                        batchnr += 1
                    except Exception as e:
                        print("Error updating batch:", e)

            if batchnr > 0:
                # the per-epoch dev accuracy is measured with both context and prior enabled
                el_pipe.cfg["incl_context"] = True
                el_pipe.cfg["incl_prior"] = True
                dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
                # report the loss averaged over the successfully updated batches
                losses["entity_linker"] = losses["entity_linker"] / batchnr
                print(
                    "Epoch, train loss",
                    itn,
                    round(losses["entity_linker"], 2),
                    " / dev accuracy avg",
                    round(dev_acc_context, 3),
                )

    # STEP 6: measure the performance of our trained pipe on an independent dev set
    print()
    if len(dev_data):
        print()
        print(now(), "STEP 6: performance measurement of Entity Linking pipe")
        print()

        counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
            dev_data, kb
        )
        print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))

        oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
        print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)

        random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
        print("dev accuracy random:", round(acc_r, 3), random_by_label)

        prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
        print("dev accuracy prior:", round(acc_p, 3), prior_by_label)

        # using only context
        el_pipe.cfg["incl_context"] = True
        el_pipe.cfg["incl_prior"] = False
        dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
        context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
        print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)

        # measuring combined accuracy (prior + context)
        el_pipe.cfg["incl_context"] = True
        el_pipe.cfg["incl_prior"] = True
        dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
        combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
        print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)

    # STEP 7: apply the EL pipe on a toy example
    print()
    print(now(), "STEP 7: applying Entity Linking to toy example")
    print()
    run_el_toy_example(nlp=nlp)

    # STEP 8: write the NLP pipeline (including entity linker) to file
    if output_dir:
        print()
        nlp_loc = output_dir / "nlp"
        print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
        nlp.to_disk(nlp_loc)
        print()

    print()
    print(now(), "Done!")
|
|
|
|
|
|
def _measure_acc(data, el_pipe=None, error_analysis=False):
    """Measure how often the predicted KB id of an entity equals its gold KB id.

    data: (doc, gold) pairs; pairs whose doc has length 0 are left out.
    el_pipe: set this if the docs still need to be processed by an entity
        linker; the docs are then run through `el_pipe.pipe` first.
    error_analysis: if True, every wrong prediction is printed.

    Returns the overall accuracy and a dict with the accuracy per entity label,
    both as computed by `calculate_acc`.
    """
    hits_by_label = {}
    misses_by_label = {}

    docs = [doc for doc, _ in data if len(doc) > 0]
    if el_pipe is not None:
        docs = list(el_pipe.pipe(docs))
    golds = [gold for doc, gold in data if len(doc) > 0]

    for doc, gold in zip(docs, golds):
        try:
            # collect the gold KB id per character offset (positive examples only)
            gold_by_offset = {}
            for (start, end), kb_dict in gold.links.items():
                for gold_kb, is_positive in kb_dict.items():
                    if is_positive:
                        gold_by_offset[_offset(start, end)] = gold_kb

            for ent in doc.ents:
                ent_label = ent.label_
                pred_entity = ent.kb_id_
                key = _offset(ent.start_char, ent.end_char)
                gold_entity = gold_by_offset.get(key)

                # the gold annotations are not complete, so an entity without
                # a gold annotation can't be counted as 'wrong'
                if gold_entity is None:
                    continue

                if gold_entity == pred_entity:
                    hits_by_label[ent_label] = hits_by_label.get(ent_label, 0) + 1
                    continue

                misses_by_label[ent_label] = misses_by_label.get(ent_label, 0) + 1
                if error_analysis:
                    print(ent.text, "in", doc)
                    print(
                        "Predicted",
                        pred_entity,
                        "should have been",
                        gold_entity,
                    )
                    print()

        except Exception as e:
            print("Error assessing accuracy", e)

    return calculate_acc(hits_by_label, misses_by_label)
|
|
|
|
|
|
def _measure_baselines(data, kb):
    """Measure 3 performance baselines on the positive gold links.

    The baselines are: picking a random KB candidate, picking the candidate with
    the highest prior probability, and an 'oracle' that picks the gold entity
    whenever it is among the candidates (an upper bound).

    data: (doc, gold) pairs; pairs whose doc has length 0 are left out.
    kb: object whose `get_candidates(text)` returns the candidates for a mention.

    Returns a 7-tuple: counts per label, then (accuracy, accuracy per label)
    for the random, prior and oracle baselines, in that order.
    """
    counts_d = {}
    random_hits, random_misses = {}, {}
    oracle_hits, oracle_misses = {}, {}
    prior_hits, prior_misses = {}, {}

    def tally(hits, misses, label, is_hit):
        # add one to the hit or miss counter of this label
        target = hits if is_hit else misses
        target[label] = target.get(label, 0) + 1

    docs = [doc for doc, _ in data if len(doc) > 0]
    golds = [gold for doc, gold in data if len(doc) > 0]

    for doc, gold in zip(docs, golds):
        try:
            # collect the gold KB id per character offset (positive examples only)
            gold_by_offset = {}
            for (start, end), kb_dict in gold.links.items():
                for gold_kb, is_positive in kb_dict.items():
                    if is_positive:
                        gold_by_offset[_offset(start, end)] = gold_kb

            for ent in doc.ents:
                label = ent.label_
                key = _offset(ent.start_char, ent.end_char)
                gold_entity = gold_by_offset.get(key)

                # the gold annotations are not complete, so an entity without
                # a gold annotation can't be counted as 'wrong'
                if gold_entity is None:
                    continue

                counts_d[label] = counts_d.get(label, 0) + 1
                candidates = kb.get_candidates(ent.text)

                # without candidates every baseline predicts the empty string
                oracle_pick = ""
                prior_pick = ""
                random_pick = ""
                if candidates:
                    for cand in candidates:
                        if cand.entity_ == gold_entity:
                            oracle_pick = cand.entity_
                    # max() keeps the first candidate in case of ties
                    prior_pick = max(candidates, key=lambda cand: cand.prior_prob).entity_
                    random_pick = random.choice(candidates).entity_

                tally(prior_hits, prior_misses, label, gold_entity == prior_pick)
                tally(random_hits, random_misses, label, gold_entity == random_pick)
                tally(oracle_hits, oracle_misses, label, gold_entity == oracle_pick)

        except Exception as e:
            print("Error assessing accuracy", e)

    acc_prior, acc_prior_d = calculate_acc(prior_hits, prior_misses)
    acc_rand, acc_rand_d = calculate_acc(random_hits, random_misses)
    acc_oracle, acc_oracle_d = calculate_acc(oracle_hits, oracle_misses)

    return (
        counts_d,
        acc_rand,
        acc_rand_d,
        acc_prior,
        acc_prior_d,
        acc_oracle,
        acc_oracle_d,
    )
|
|
|
|
|
|
def _offset(start, end):
|
|
return "{}_{}".format(start, end)
|
|
|
|
|
|
def calculate_acc(correct_by_label, incorrect_by_label):
    """Compute the overall and the per-label accuracy from two count tables.

    correct_by_label: dict mapping a label to its number of correct predictions.
    incorrect_by_label: dict mapping a label to its number of incorrect predictions.

    Returns a tuple (acc, acc_by_label): `acc` is correct / (correct + incorrect)
    summed over all labels, `acc_by_label` holds the same ratio per label for
    every label that occurs in either input. A ratio without any observations
    is reported as 0.
    """
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
    all_keys = set(correct_by_label) | set(incorrect_by_label)
    for label in sorted(all_keys):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        if correct == incorrect == 0:
            acc_by_label[label] = 0
        else:
            # float() forces true division: this file does not import `division`
            # from __future__, so int / int would truncate to 0 or 1 on Python 2
            acc_by_label[label] = float(correct) / (correct + incorrect)
    acc = 0
    if not (total_correct == total_incorrect == 0):
        acc = float(total_correct) / (total_correct + total_incorrect)
    return acc, acc_by_label
|
|
|
|
|
|
def check_kb(kb):
    """Print the KB candidates of a few fixed example mentions.

    kb: object whose `get_candidates(mention)` returns candidates that expose
        `prior_prob`, `alias_`, `entity_` and `entity_freq`.
    """
    example_mentions = ("Bush", "Douglas Adams", "Homer", "Brazil", "China")
    for mention in example_mentions:
        found = kb.get_candidates(mention)

        print("generating candidates for " + mention + " :")
        for candidate in found:
            entity_info = candidate.entity_ + " (freq=" + str(candidate.entity_freq) + ")"
            print(" ", candidate.prior_prob, candidate.alias_, "-->", entity_info)
        print()
|
|
|
|
|
|
def run_el_toy_example(nlp):
    """Run `nlp` on a fixed toy text and print the entities it finds.

    nlp: callable that turns a text into a doc; for every entity in `doc.ents`
        the text, label and KB id are printed.
    """
    pieces = (
        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, ",
        "Douglas reminds us to always bring our towel, even in China or Brazil. ",
        "The main character in Doug's novel is the man Arthur Dent, ",
        "but Dougledydoug doesn't write about George Washington or Homer Simpson.",
    )
    text = "".join(pieces)
    doc = nlp(text)
    print(text)
    for ent in doc.ents:
        print(" ent", ent.text, ent.label_, ent.kb_id_)
    print()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script: plac calls `main` with the
    # command-line arguments described by its @plac.annotations decorator.
    plac.call(main)
|