diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py
index 8d293a0a1..bd82e5b4e 100644
--- a/bin/wiki_entity_linking/kb_creator.py
+++ b/bin/wiki_entity_linking/kb_creator.py
@@ -1,31 +1,31 @@
# coding: utf-8
from __future__ import unicode_literals
-from bin.wiki_entity_linking.train_descriptions import EntityEncoder
+from .train_descriptions import EntityEncoder
+from . import wikidata_processor as wd, wikipedia_processor as wp
from spacy.kb import KnowledgeBase
import csv
import datetime
-from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
-INPUT_DIM = 300 # dimension of pre-trained vectors
-DESC_WIDTH = 64
+INPUT_DIM = 300 # dimension of pre-trained input vectors
+DESC_WIDTH = 64 # dimension of output entity vectors
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
entity_def_output, entity_descr_output,
- count_input, prior_prob_input, to_print=False):
+ count_input, prior_prob_input, wikidata_input):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
- read_raw_data = False
+ read_raw_data = True
if read_raw_data:
print()
print(" * _read_wikidata_entities", datetime.datetime.now())
- title_to_id, id_to_descr = wd.read_wikidata_entities_json(limit=None)
+ title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
# write the title-ID and ID-description mappings to file
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
@@ -40,7 +40,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print()
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
- # filter the entities for in the KB by frequency, because there's just too much data otherwise
+ # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id = dict()
entity_list = list()
description_list = list()
@@ -60,11 +60,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print()
print(" * train entity encoder", datetime.datetime.now())
print()
-
encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH)
encoder.train(description_list=description_list, to_print=True)
- print()
+ print()
print(" * get entity embeddings", datetime.datetime.now())
print()
embeddings = encoder.apply_encoder(description_list)
@@ -80,12 +79,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
prior_prob_input=prior_prob_input)
- if to_print:
- print()
- print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
+ print()
+ print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
print("done with kb", datetime.datetime.now())
-
return kb
@@ -94,6 +91,7 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
+
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
@@ -108,7 +106,6 @@ def get_entity_to_id(entity_def_output):
next(csvreader)
for row in csvreader:
entity_to_id[row[0]] = row[1]
-
return entity_to_id
@@ -120,16 +117,12 @@ def _get_id_to_description(entity_descr_output):
next(csvreader)
for row in csvreader:
id_to_desc[row[0]] = row[1]
-
return id_to_desc
-def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input, to_print=False):
+def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
wp_titles = title_to_id.keys()
- if to_print:
- print("wp titles:", wp_titles)
-
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
@@ -176,6 +169,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
line = prior_file.readline()
- if to_print:
- print("added", kb.get_size_aliases(), "aliases:", kb.get_alias_strings())
-
diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py
index 82db582dc..948a0e2d1 100644
--- a/bin/wiki_entity_linking/train_descriptions.py
+++ b/bin/wiki_entity_linking/train_descriptions.py
@@ -32,8 +32,6 @@ class EntityEncoder:
if self.encoder is None:
raise ValueError("Can not apply encoder before training it")
- print("Encoding", len(description_list), "entities")
-
batch_size = 100000
start = 0
@@ -48,13 +46,11 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
- print("encoded :", len(encodings))
return encodings
def train(self, description_list, to_print=False):
processed, loss = self._train_model(description_list)
-
if to_print:
print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
print("Final loss:", loss)
@@ -111,15 +107,12 @@ class EntityEncoder:
Affine(hidden_with, orig_width)
)
self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0))
-
self.sgd = create_default_optimizer(self.model.ops)
def _update(self, vectors):
predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP)
-
loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors))
bp_model(d_scores, sgd=self.sgd)
-
return loss / len(vectors)
@staticmethod
diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py
index 90df5d9fc..eb9f8af78 100644
--- a/bin/wiki_entity_linking/training_set_creator.py
+++ b/bin/wiki_entity_linking/training_set_creator.py
@@ -18,23 +18,21 @@ Gold-standard entities are stored in one file in standoff format (by character o
ENTITY_FILE = "gold_entities_1000000.csv" # use this file for faster processing
-def create_training(entity_def_input, training_output):
+def create_training(wikipedia_input, entity_def_input, training_output):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
- _process_wikipedia_texts(wp_to_id, training_output, limit=None)
+ _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
-def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
+def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
-
title_regex = re.compile(r'(?<=
).*(?=)')
id_regex = re.compile(r'(?<=)\d*(?=)')
read_ids = set()
-
- entityfile_loc = training_output + "/" + ENTITY_FILE
+ entityfile_loc = training_output / ENTITY_FILE
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
# write entity training header file
_write_training_entity(outputfile=entityfile,
@@ -44,7 +42,7 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
start="start",
end="end")
- with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
+ with bz2.open(wikipedia_input, mode='rb') as file:
line = file.readline()
cnt = 0
article_text = ""
@@ -104,7 +102,7 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
print("Found duplicate article ID", article_id, clean_line) # This should never happen ...
read_ids.add(article_id)
- # read the title of this article (outside the revision portion of the document)
+ # read the title of this article (outside the revision portion of the document)
if not reading_revision:
titles = title_regex.search(clean_line)
if titles:
@@ -134,7 +132,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
# get the raw text without markup etc, keeping only interwiki links
clean_text = _get_clean_wp_text(text)
- # read the text char by char to get the right offsets of the interwiki links
+ # read the text char by char to get the right offsets for the interwiki links
final_text = ""
open_read = 0
reading_text = True
@@ -274,7 +272,7 @@ def _get_clean_wp_text(article_text):
def _write_training_article(article_id, clean_text, training_output):
- file_loc = training_output + "/" + str(article_id) + ".txt"
+ file_loc = training_output / str(article_id) + ".txt"
with open(file_loc, mode='w', encoding='utf8') as outputfile:
outputfile.write(clean_text)
@@ -289,11 +287,10 @@ def is_dev(article_id):
def read_training(nlp, training_dir, dev, limit):
# This method provides training examples that correspond to the entity annotations found by the nlp object
-
- entityfile_loc = training_dir + "/" + ENTITY_FILE
+ entityfile_loc = training_dir / ENTITY_FILE
data = []
- # we assume the data is written sequentially
+ # assume the data is written sequentially, so we can reuse the article docs
current_article_id = None
current_doc = None
ents_by_offset = dict()
@@ -347,10 +344,10 @@ def read_training(nlp, training_dir, dev, limit):
gold_end = int(end) - found_ent.sent.start_char
gold_entities = list()
gold_entities.append((gold_start, gold_end, wp_title))
- gold = GoldParse(doc=current_doc, links=gold_entities)
+ gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1
- if len(data) % 500 == 0:
+ if len(data) % 2500 == 0:
print(" -read", total_entities, "entities")
print(" -read", total_entities, "entities")
diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py
index 85d3d8488..a32a0769a 100644
--- a/bin/wiki_entity_linking/wikidata_processor.py
+++ b/bin/wiki_entity_linking/wikidata_processor.py
@@ -5,17 +5,15 @@ import bz2
import json
import datetime
-# TODO: remove hardcoded paths
-WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.json.bz2'
-
-def read_wikidata_entities_json(limit=None, to_print=False):
+def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
+ # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
lang = 'en'
site_filter = 'enwiki'
- # filter currently disabled to get ALL data
+ # properties filter (currently disabled to get ALL data)
prop_filter = dict()
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
@@ -30,7 +28,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
parse_aliases = False
parse_claims = False
- with bz2.open(WIKIDATA_JSON, mode='rb') as file:
+ with bz2.open(wikidata_file, mode='rb') as file:
line = file.readline()
cnt = 0
while line and (not limit or cnt < limit):
diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py
index d957fc58c..c02e472bc 100644
--- a/bin/wiki_entity_linking/wikipedia_processor.py
+++ b/bin/wiki_entity_linking/wikipedia_processor.py
@@ -11,11 +11,6 @@ Process a Wikipedia dump to calculate entity frequencies and prior probabilities
Write these results to file for downstream KB and training data generation.
"""
-
-# TODO: remove hardcoded paths
-ENWIKI_DUMP = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream.xml.bz2'
-ENWIKI_INDEX = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream-index.txt.bz2'
-
map_alias_to_link = dict()
# these will/should be matched ignoring case
@@ -46,15 +41,13 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
-def read_wikipedia_prior_probs(prior_prob_output):
+def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
"""
- Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities
- The full file takes about 2h to parse 1100M lines (update printed every 5M lines).
- It works relatively fast because we don't care about which article we parsed the interwiki from,
- we just process line by line.
+ Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
+ The full file takes about 2h to parse 1100M lines.
+ It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
"""
-
- with bz2.open(ENWIKI_DUMP, mode='rb') as file:
+ with bz2.open(wikipedia_input, mode='rb') as file:
line = file.readline()
cnt = 0
while line:
@@ -70,7 +63,7 @@ def read_wikipedia_prior_probs(prior_prob_output):
line = file.readline()
cnt += 1
- # write all aliases and their entities and occurrences to file
+ # write all aliases and their entities and count occurrences to file
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
@@ -108,7 +101,7 @@ def get_wp_links(text):
if ns_regex.match(match):
pass # ignore namespaces at the beginning of the string
- # this is a simple link, with the alias the same as the mention
+ # this is a simple [[link]], with the alias the same as the mention
elif "|" not in match:
aliases.append(match)
entities.append(match)
diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py
index c282c7262..aa1c00996 100644
--- a/examples/pipeline/wikidata_entity_linking.py
+++ b/examples/pipeline/wikidata_entity_linking.py
@@ -2,35 +2,45 @@
from __future__ import unicode_literals
import random
-
-from spacy.util import minibatch, compounding
+import datetime
+from pathlib import Path
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy
from spacy.kb import KnowledgeBase
-import datetime
+from spacy.util import minibatch, compounding
"""
Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm.
"""
-PRIOR_PROB = 'C:/Users/Sofie/Documents/data/wikipedia/prior_prob.csv'
-ENTITY_COUNTS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_freq.csv'
-ENTITY_DEFS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_defs.csv'
-ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv'
+ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
+OUTPUT_DIR = ROOT_DIR / 'wikipedia'
+TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
-KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb_1/kb'
-NLP_1_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_1'
-NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
+PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
+ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
+ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
+ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
-TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
+KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
+NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
+NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
+# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
+WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
+
+# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
+ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
+
+# KB construction parameters
MAX_CANDIDATES = 10
MIN_ENTITY_FREQ = 20
MIN_PAIR_OCC = 5
+# model training parameters
EPOCHS = 10
DROPOUT = 0.1
LEARN_RATE = 0.005
@@ -38,6 +48,7 @@ L2 = 1e-6
def run_pipeline():
+ # set the appropriate booleans to define which parts of the pipeline should be re(run)
print("START", datetime.datetime.now())
print()
nlp_1 = spacy.load('en_core_web_lg')
@@ -67,22 +78,19 @@ def run_pipeline():
to_write_nlp = False
to_read_nlp = False
- # STEP 1 : create prior probabilities from WP
- # run only once !
+ # STEP 1 : create prior probabilities from WP (run only once)
if to_create_prior_probs:
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
- wp.read_wikipedia_prior_probs(prior_prob_output=PRIOR_PROB)
+ wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
print()
- # STEP 2 : deduce entity frequencies from WP
- # run only once !
+ # STEP 2 : deduce entity frequencies from WP (run only once)
if to_create_entity_counts:
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
print()
- # STEP 3 : create KB and write to file
- # run only once !
+ # STEP 3 : create KB and write to file (run only once)
if to_create_kb:
print("STEP 3a: to_create_kb", datetime.datetime.now())
kb_1 = kb_creator.create_kb(nlp_1,
@@ -93,7 +101,7 @@ def run_pipeline():
entity_descr_output=ENTITY_DESCR,
count_input=ENTITY_COUNTS,
prior_prob_input=PRIOR_PROB,
- to_print=False)
+ wikidata_input=WIKIDATA_JSON)
print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases())
print()
@@ -121,7 +129,9 @@ def run_pipeline():
# STEP 5: create a training dataset from WP
if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now())
- training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
+ training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
+ entity_def_input=ENTITY_DEFS,
+ training_output=TRAINING_DIR)
# STEP 6: create and train the entity linking pipe
el_pipe = nlp_2.create_pipe(name='entity_linker', config={})
@@ -136,7 +146,8 @@ def run_pipeline():
if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
- train_limit = 25000
+ # define the size (nr of entities) of training and dev set
+ train_limit = 10000
dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2,
@@ -157,7 +168,6 @@ def run_pipeline():
if not train_data:
print("Did not find any training data")
-
else:
for itn in range(EPOCHS):
random.shuffle(train_data)
@@ -196,7 +206,7 @@ def run_pipeline():
print()
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
- print("dev counts:", sorted(counts))
+ print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
@@ -215,7 +225,6 @@ def run_pipeline():
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
print("dev acc context avg:", round(dev_acc_context, 3),
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
- print()
# reset for follow-up tests
el_pipe.context_weight = 1
@@ -227,7 +236,6 @@ def run_pipeline():
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
print()
run_el_toy_example(nlp=nlp_2)
- print()
# STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp:
@@ -400,26 +408,9 @@ def run_el_toy_example(nlp):
doc = nlp(text)
print(text)
for ent in doc.ents:
- print("ent", ent.text, ent.label_, ent.kb_id_)
+ print(" ent", ent.text, ent.label_, ent.kb_id_)
print()
- # Q4426480 is her husband
- text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
- "She loved her husband William King dearly. "
- doc = nlp(text)
- print(text)
- for ent in doc.ents:
- print("ent", ent.text, ent.label_, ent.kb_id_)
- print()
-
- # Q3568763 is her tutor
- text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
- "She was tutored by her favorite physics tutor William King."
- doc = nlp(text)
- print(text)
- for ent in doc.ents:
- print("ent", ent.text, ent.label_, ent.kb_id_)
-
if __name__ == "__main__":
run_pipeline()
diff --git a/spacy/kb.pxd b/spacy/kb.pxd
index 9c5a73d59..ccf150cd2 100644
--- a/spacy/kb.pxd
+++ b/spacy/kb.pxd
@@ -18,7 +18,6 @@ ctypedef vector[float_vec] float_matrix
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
cdef class Candidate:
-
cdef readonly KnowledgeBase kb
cdef hash_t entity_hash
cdef float entity_freq
@@ -143,7 +142,6 @@ cdef class KnowledgeBase:
cpdef load_bulk(self, loc)
cpdef set_entities(self, entity_list, prob_list, vector_list)
- cpdef set_aliases(self, alias_list, entities_list, probabilities_list)
cdef class Writer:
diff --git a/spacy/kb.pyx b/spacy/kb.pyx
index 9a84439ea..72f66b107 100644
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -1,23 +1,16 @@
# cython: infer_types=True
# cython: profile=True
# coding: utf8
-from collections import OrderedDict
-from pathlib import Path, WindowsPath
-
-from cpython.exc cimport PyErr_CheckSignals
-
-from spacy import util
from spacy.errors import Errors, Warnings, user_warning
+from pathlib import Path
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
-from cpython.mem cimport PyMem_Malloc
from cpython.exc cimport PyErr_SetFromErrno
-from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
+from libc.stdio cimport fopen, fclose, fread, fwrite, feof, fseek
from libc.stdint cimport int32_t, int64_t
-from libc.stdlib cimport qsort
from .typedefs cimport hash_t
@@ -25,7 +18,6 @@ from os import path
from libcpp.vector cimport vector
-
cdef class Candidate:
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob):
@@ -79,8 +71,6 @@ cdef class KnowledgeBase:
self._entry_index = PreshMap()
self._alias_index = PreshMap()
- # Should we initialize self._entries and self._aliases_table to specific starting size ?
-
self.vocab.strings.add("")
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
@@ -165,47 +155,11 @@ cdef class KnowledgeBase:
i += 1
- # TODO: this method is untested
- cpdef set_aliases(self, alias_list, entities_list, probabilities_list):
- nr_aliases = len(alias_list)
- self._alias_index = PreshMap(nr_aliases+1)
- self._aliases_table = alias_vec(nr_aliases+1)
-
- i = 0
- cdef AliasC alias
- cdef int32_t dummy_value = 342
- while i <= nr_aliases:
- alias_hash = self.vocab.strings.add(alias_list[i])
- entities = entities_list[i]
- probabilities = probabilities_list[i]
-
- nr_candidates = len(entities)
- entry_indices = vector[int64_t](nr_candidates)
- probs = vector[float](nr_candidates)
-
- for j in range(0, nr_candidates):
- entity = entities[j]
- entity_hash = self.vocab.strings[entity]
- if not entity_hash in self._entry_index:
- raise ValueError(Errors.E134.format(alias=alias, entity=entity))
-
- entry_index = self._entry_index.get(entity_hash)
- entry_indices[j] = entry_index
-
- alias.entry_indices = entry_indices
- alias.probs = probs
-
- self._aliases_table[i] = alias
- self._alias_index[alias_hash] = i
-
- i += 1
-
def add_alias(self, unicode alias, entities, probabilities):
"""
For a given alias, add its potential entities and prior probabilies to the KB.
Return the alias_hash at the end
"""
-
# Throw an error if the length of entities and probabilities are not the same
if not len(entities) == len(probabilities):
raise ValueError(Errors.E132.format(alias=alias,
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 99c361964..1c430a90b 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1068,8 +1068,6 @@ class EntityLinker(Pipe):
DOCS: TODO
"""
name = 'entity_linker'
- context_weight = 1
- prior_weight = 1
@classmethod
def Model(cls, **cfg):
@@ -1078,18 +1076,17 @@ class EntityLinker(Pipe):
embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 128)
-
- # no default because this needs to correspond with the KB entity length
- entity_width = cfg.get("entity_width")
+ entity_width = cfg.get("entity_width") # this needs to correspond with the KB entity length
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg)
-
return model
def __init__(self, **cfg):
self.model = True
self.kb = None
self.cfg = dict(cfg)
+ self.context_weight = cfg.get("context_weight", 1)
+ self.prior_weight = cfg.get("prior_weight", 1)
def set_kb(self, kb):
self.kb = kb
@@ -1162,7 +1159,6 @@ class EntityLinker(Pipe):
if losses is not None:
losses[self.name] += loss
return loss
-
return 0
def get_loss(self, docs, golds, scores):
@@ -1224,7 +1220,7 @@ class EntityLinker(Pipe):
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
- score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
+ score = prior_prob + sim - (prior_prob*sim)
scores.append(score)
# TODO: thresholding