From a31648d28be3ed10a3f8ba5cefc85f94ce22b715 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 19 Jun 2019 09:15:43 +0200 Subject: [PATCH] further code cleanup --- bin/wiki_entity_linking/kb_creator.py | 36 ++++----- bin/wiki_entity_linking/train_descriptions.py | 7 -- .../training_set_creator.py | 27 +++---- bin/wiki_entity_linking/wikidata_processor.py | 10 +-- .../wikipedia_processor.py | 21 ++--- examples/pipeline/wikidata_entity_linking.py | 77 ++++++++----------- spacy/kb.pxd | 2 - spacy/kb.pyx | 50 +----------- spacy/pipeline/pipes.pyx | 12 +-- 9 files changed, 76 insertions(+), 166 deletions(-) diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py index 8d293a0a1..bd82e5b4e 100644 --- a/bin/wiki_entity_linking/kb_creator.py +++ b/bin/wiki_entity_linking/kb_creator.py @@ -1,31 +1,31 @@ # coding: utf-8 from __future__ import unicode_literals -from bin.wiki_entity_linking.train_descriptions import EntityEncoder +from .train_descriptions import EntityEncoder +from . import wikidata_processor as wd, wikipedia_processor as wp from spacy.kb import KnowledgeBase import csv import datetime -from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp -INPUT_DIM = 300 # dimension of pre-trained vectors -DESC_WIDTH = 64 +INPUT_DIM = 300 # dimension of pre-trained input vectors +DESC_WIDTH = 64 # dimension of output entity vectors def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, entity_def_output, entity_descr_output, - count_input, prior_prob_input, to_print=False): + count_input, prior_prob_input, wikidata_input): # Create the knowledge base from Wikidata entries kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH) # disable this part of the pipeline when rerunning the KB generation from preprocessed files - read_raw_data = False + read_raw_data = True if read_raw_data: print() print(" * _read_wikidata_entities", datetime.datetime.now()) - title_to_id, id_to_descr = wd.read_wikidata_entities_json(limit=None) + title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input) # write the title-ID and ID-description mappings to file _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr) @@ -40,7 +40,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, print() entity_frequencies = wp.get_all_frequencies(count_input=count_input) - # filter the entities for in the KB by frequency, because there's just too much data otherwise + # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise filtered_title_to_id = dict() entity_list = list() description_list = list() @@ -60,11 +60,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, print() print(" * train entity encoder", datetime.datetime.now()) print() - encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH) encoder.train(description_list=description_list, to_print=True) - print() + print() print(" * get entity embeddings", datetime.datetime.now()) print() embeddings = encoder.apply_encoder(description_list) @@ -80,12 +79,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, max_entities_per_alias=max_entities_per_alias, min_occ=min_occ, prior_prob_input=prior_prob_input) - if to_print: - print() - print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases()) + print() + print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases()) print("done with kb", datetime.datetime.now()) - 
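    # --- Editor's note: illustrative sketch, not part of this commit ---
    # With this change, create_kb() takes the Wikidata dump location via `wikidata_input`
    # instead of relying on a hardcoded path. A hypothetical end-to-end call (file names
    # are placeholders; kb.dump() is assumed to be available alongside load_bulk()):
    #
    #   kb = create_kb(nlp, max_entities_per_alias=10, min_entity_freq=20, min_occ=5,
    #                  entity_def_output="entity_defs.csv",
    #                  entity_descr_output="entity_descriptions.csv",
    #                  count_input="entity_freq.csv",
    #                  prior_prob_input="prior_prob.csv",
    #                  wikidata_input="latest-all.json.bz2")
    #   kb.dump("kb_dir/kb")
    #   kb_reloaded = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
    #   kb_reloaded.load_bulk("kb_dir/kb")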
return kb @@ -94,6 +91,7 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_ id_file.write("WP_title" + "|" + "WD_id" + "\n") for title, qid in title_to_id.items(): id_file.write(title + "|" + str(qid) + "\n") + with open(entity_descr_output, mode='w', encoding='utf8') as descr_file: descr_file.write("WD_id" + "|" + "description" + "\n") for qid, descr in id_to_descr.items(): @@ -108,7 +106,6 @@ def get_entity_to_id(entity_def_output): next(csvreader) for row in csvreader: entity_to_id[row[0]] = row[1] - return entity_to_id @@ -120,16 +117,12 @@ def _get_id_to_description(entity_descr_output): next(csvreader) for row in csvreader: id_to_desc[row[0]] = row[1] - return id_to_desc -def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input, to_print=False): +def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): wp_titles = title_to_id.keys() - if to_print: - print("wp titles:", wp_titles) - # adding aliases with prior probabilities # we can read this file sequentially, it's sorted by alias, and then by count with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: @@ -176,6 +169,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in line = prior_file.readline() - if to_print: - print("added", kb.get_size_aliases(), "aliases:", kb.get_alias_strings()) - diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py index 82db582dc..948a0e2d1 100644 --- a/bin/wiki_entity_linking/train_descriptions.py +++ b/bin/wiki_entity_linking/train_descriptions.py @@ -32,8 +32,6 @@ class EntityEncoder: if self.encoder is None: raise ValueError("Can not apply encoder before training it") - print("Encoding", len(description_list), "entities") - batch_size = 100000 start = 0 @@ -48,13 +46,11 @@ class EntityEncoder: start = start + batch_size stop = min(stop + batch_size, len(description_list)) - print("encoded :", len(encodings)) return encodings def train(self, description_list, to_print=False): processed, loss = self._train_model(description_list) - if to_print: print("Trained on", processed, "entities across", self.EPOCHS, "epochs") print("Final loss:", loss) @@ -111,15 +107,12 @@ class EntityEncoder: Affine(hidden_with, orig_width) ) self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0)) - self.sgd = create_default_optimizer(self.model.ops) def _update(self, vectors): predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP) - loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors)) bp_model(d_scores, sgd=self.sgd) - return loss / len(vectors) @staticmethod diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py index 90df5d9fc..eb9f8af78 100644 --- a/bin/wiki_entity_linking/training_set_creator.py +++ b/bin/wiki_entity_linking/training_set_creator.py @@ -18,23 +18,21 @@ Gold-standard entities are stored in one file in standoff format (by character o ENTITY_FILE = "gold_entities_1000000.csv" # use this file for faster processing -def create_training(entity_def_input, training_output): +def create_training(wikipedia_input, entity_def_input, training_output): wp_to_id = kb_creator.get_entity_to_id(entity_def_input) - _process_wikipedia_texts(wp_to_id, training_output, limit=None) + _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None) -def 
_process_wikipedia_texts(wp_to_id, training_output, limit=None): +def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None): """ Read the XML wikipedia data to parse out training data: raw text data + positive instances """ - title_regex = re.compile(r'(?<=).*(?=)') id_regex = re.compile(r'(?<=)\d*(?=)') read_ids = set() - - entityfile_loc = training_output + "/" + ENTITY_FILE + entityfile_loc = training_output / ENTITY_FILE with open(entityfile_loc, mode="w", encoding='utf8') as entityfile: # write entity training header file _write_training_entity(outputfile=entityfile, @@ -44,7 +42,7 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None): start="start", end="end") - with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file: + with bz2.open(wikipedia_input, mode='rb') as file: line = file.readline() cnt = 0 article_text = "" @@ -104,7 +102,7 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None): print("Found duplicate article ID", article_id, clean_line) # This should never happen ... read_ids.add(article_id) - # read the title of this article (outside the revision portion of the document) + # read the title of this article (outside the revision portion of the document) if not reading_revision: titles = title_regex.search(clean_line) if titles: @@ -134,7 +132,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te # get the raw text without markup etc, keeping only interwiki links clean_text = _get_clean_wp_text(text) - # read the text char by char to get the right offsets of the interwiki links + # read the text char by char to get the right offsets for the interwiki links final_text = "" open_read = 0 reading_text = True @@ -274,7 +272,7 @@ def _get_clean_wp_text(article_text): def _write_training_article(article_id, clean_text, training_output): - file_loc = training_output + "/" + str(article_id) + ".txt" + file_loc = training_output / str(article_id) + ".txt" with open(file_loc, mode='w', encoding='utf8') as outputfile: outputfile.write(clean_text) @@ -289,11 +287,10 @@ def is_dev(article_id): def read_training(nlp, training_dir, dev, limit): # This method provides training examples that correspond to the entity annotations found by the nlp object - - entityfile_loc = training_dir + "/" + ENTITY_FILE + entityfile_loc = training_dir / ENTITY_FILE data = [] - # we assume the data is written sequentially + # assume the data is written sequentially, so we can reuse the article docs current_article_id = None current_doc = None ents_by_offset = dict() @@ -347,10 +344,10 @@ def read_training(nlp, training_dir, dev, limit): gold_end = int(end) - found_ent.sent.start_char gold_entities = list() gold_entities.append((gold_start, gold_end, wp_title)) - gold = GoldParse(doc=current_doc, links=gold_entities) + gold = GoldParse(doc=sent, links=gold_entities) data.append((sent, gold)) total_entities += 1 - if len(data) % 500 == 0: + if len(data) % 2500 == 0: print(" -read", total_entities, "entities") print(" -read", total_entities, "entities") diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py index 85d3d8488..a32a0769a 100644 --- a/bin/wiki_entity_linking/wikidata_processor.py +++ b/bin/wiki_entity_linking/wikidata_processor.py @@ -5,17 +5,15 @@ import bz2 import json import datetime -# TODO: remove hardcoded paths -WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.json.bz2' - -def read_wikidata_entities_json(limit=None, 
to_print=False):
+def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
     # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
+    # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
 
     lang = 'en'
     site_filter = 'enwiki'
 
-    # filter currently disabled to get ALL data
+    # properties filter (currently disabled to get ALL data)
     prop_filter = dict()
     # prop_filter = {'P31': {'Q5', 'Q15632617'}}  # currently defined as OR: one property suffices to be selected
 
@@ -30,7 +28,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
     parse_aliases = False
     parse_claims = False
 
-    with bz2.open(WIKIDATA_JSON, mode='rb') as file:
+    with bz2.open(wikidata_file, mode='rb') as file:
         line = file.readline()
         cnt = 0
         while line and (not limit or cnt < limit):
diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py
index d957fc58c..c02e472bc 100644
--- a/bin/wiki_entity_linking/wikipedia_processor.py
+++ b/bin/wiki_entity_linking/wikipedia_processor.py
@@ -11,11 +11,6 @@ Process a Wikipedia dump to calculate entity frequencies and prior probabilities
 Write these results to file for downstream KB and training data generation.
 """
 
-
-# TODO: remove hardcoded paths
-ENWIKI_DUMP = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream.xml.bz2'
-ENWIKI_INDEX = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream-index.txt.bz2'
-
 map_alias_to_link = dict()
 
 # these will/should be matched ignoring case
@@ -46,15 +41,13 @@ for ns in wiki_namespaces:
 ns_regex = re.compile(ns_regex, re.IGNORECASE)
 
 
-def read_wikipedia_prior_probs(prior_prob_output):
+def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
     """
-    Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities
-    The full file takes about 2h to parse 1100M lines (update printed every 5M lines).
-    It works relatively fast because we don't care about which article we parsed the interwiki from,
-    we just process line by line.
+    Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
+    The full file takes about 2h to parse 1100M lines.
+    It works relatively fast because it runs line by line, regardless of which article the intra-wiki link comes from.
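    Editor's note (illustrative, not part of this commit): downstream, kb_creator._add_aliases
    turns these counts into per-alias prior probabilities, roughly
        P(entity | alias) = count(alias -> entity) / sum over entities e of count(alias -> e),
    keeping at most max_entities_per_alias candidates per alias with at least min_occ occurrences.
    For example, an alias seen 100 times, of which 75 point to one entity, gives that entity a prior of 0.75.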
""" - - with bz2.open(ENWIKI_DUMP, mode='rb') as file: + with bz2.open(wikipedia_input, mode='rb') as file: line = file.readline() cnt = 0 while line: @@ -70,7 +63,7 @@ def read_wikipedia_prior_probs(prior_prob_output): line = file.readline() cnt += 1 - # write all aliases and their entities and occurrences to file + # write all aliases and their entities and count occurrences to file with open(prior_prob_output, mode='w', encoding='utf8') as outputfile: outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n") for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]): @@ -108,7 +101,7 @@ def get_wp_links(text): if ns_regex.match(match): pass # ignore namespaces at the beginning of the string - # this is a simple link, with the alias the same as the mention + # this is a simple [[link]], with the alias the same as the mention elif "|" not in match: aliases.append(match) entities.append(match) diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py index c282c7262..aa1c00996 100644 --- a/examples/pipeline/wikidata_entity_linking.py +++ b/examples/pipeline/wikidata_entity_linking.py @@ -2,35 +2,45 @@ from __future__ import unicode_literals import random - -from spacy.util import minibatch, compounding +import datetime +from pathlib import Path from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp from bin.wiki_entity_linking.kb_creator import DESC_WIDTH import spacy from spacy.kb import KnowledgeBase -import datetime +from spacy.util import minibatch, compounding """ Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm. """ -PRIOR_PROB = 'C:/Users/Sofie/Documents/data/wikipedia/prior_prob.csv' -ENTITY_COUNTS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_freq.csv' -ENTITY_DEFS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_defs.csv' -ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv' +ROOT_DIR = Path("C:/Users/Sofie/Documents/data/") +OUTPUT_DIR = ROOT_DIR / 'wikipedia' +TRAINING_DIR = OUTPUT_DIR / 'training_data_nel' -KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb_1/kb' -NLP_1_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_1' -NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2' +PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv' +ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv' +ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv' +ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv' -TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/' +KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb' +NLP_1_DIR = OUTPUT_DIR / 'nlp_1' +NLP_2_DIR = OUTPUT_DIR / 'nlp_2' +# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ +WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2' + +# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/ +ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2' + +# KB construction parameters MAX_CANDIDATES = 10 MIN_ENTITY_FREQ = 20 MIN_PAIR_OCC = 5 +# model training parameters EPOCHS = 10 DROPOUT = 0.1 LEARN_RATE = 0.005 @@ -38,6 +48,7 @@ L2 = 1e-6 def run_pipeline(): + # set the appropriate booleans to define which parts of the pipeline should be re(run) print("START", datetime.datetime.now()) print() nlp_1 = spacy.load('en_core_web_lg') @@ -67,22 +78,19 @@ def run_pipeline(): to_write_nlp = False to_read_nlp = False - # STEP 1 : create prior 
probabilities from WP - # run only once ! + # STEP 1 : create prior probabilities from WP (run only once) if to_create_prior_probs: print("STEP 1: to_create_prior_probs", datetime.datetime.now()) - wp.read_wikipedia_prior_probs(prior_prob_output=PRIOR_PROB) + wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB) print() - # STEP 2 : deduce entity frequencies from WP - # run only once ! + # STEP 2 : deduce entity frequencies from WP (run only once) if to_create_entity_counts: print("STEP 2: to_create_entity_counts", datetime.datetime.now()) wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False) print() - # STEP 3 : create KB and write to file - # run only once ! + # STEP 3 : create KB and write to file (run only once) if to_create_kb: print("STEP 3a: to_create_kb", datetime.datetime.now()) kb_1 = kb_creator.create_kb(nlp_1, @@ -93,7 +101,7 @@ def run_pipeline(): entity_descr_output=ENTITY_DESCR, count_input=ENTITY_COUNTS, prior_prob_input=PRIOR_PROB, - to_print=False) + wikidata_input=WIKIDATA_JSON) print("kb entities:", kb_1.get_size_entities()) print("kb aliases:", kb_1.get_size_aliases()) print() @@ -121,7 +129,9 @@ def run_pipeline(): # STEP 5: create a training dataset from WP if create_wp_training: print("STEP 5: create training dataset", datetime.datetime.now()) - training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR) + training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP, + entity_def_input=ENTITY_DEFS, + training_output=TRAINING_DIR) # STEP 6: create and train the entity linking pipe el_pipe = nlp_2.create_pipe(name='entity_linker', config={}) @@ -136,7 +146,8 @@ def run_pipeline(): if train_pipe: print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) - train_limit = 25000 + # define the size (nr of entities) of training and dev set + train_limit = 10000 dev_limit = 5000 train_data = training_set_creator.read_training(nlp=nlp_2, @@ -157,7 +168,6 @@ def run_pipeline(): if not train_data: print("Did not find any training data") - else: for itn in range(EPOCHS): random.shuffle(train_data) @@ -196,7 +206,7 @@ def run_pipeline(): print() counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2) - print("dev counts:", sorted(counts)) + print("dev counts:", sorted(counts.items(), key=lambda x: x[0])) print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()]) print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()]) print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()]) @@ -215,7 +225,6 @@ def run_pipeline(): dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) print("dev acc context avg:", round(dev_acc_context, 3), [(x, round(y, 3)) for x, y in dev_acc_context_dict.items()]) - print() # reset for follow-up tests el_pipe.context_weight = 1 @@ -227,7 +236,6 @@ def run_pipeline(): print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) print() run_el_toy_example(nlp=nlp_2) - print() # STEP 9: write the NLP pipeline (including entity linker) to file if to_write_nlp: @@ -400,26 +408,9 @@ def run_el_toy_example(nlp): doc = nlp(text) print(text) for ent in doc.ents: - print("ent", ent.text, ent.label_, ent.kb_id_) + print(" ent", ent.text, ent.label_, ent.kb_id_) print() - # Q4426480 is her husband - text = "Ada Lovelace was the countess of Lovelace. 
She's known for her programming work on the analytical engine. "\ - "She loved her husband William King dearly. " - doc = nlp(text) - print(text) - for ent in doc.ents: - print("ent", ent.text, ent.label_, ent.kb_id_) - print() - - # Q3568763 is her tutor - text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\ - "She was tutored by her favorite physics tutor William King." - doc = nlp(text) - print(text) - for ent in doc.ents: - print("ent", ent.text, ent.label_, ent.kb_id_) - if __name__ == "__main__": run_pipeline() diff --git a/spacy/kb.pxd b/spacy/kb.pxd index 9c5a73d59..ccf150cd2 100644 --- a/spacy/kb.pxd +++ b/spacy/kb.pxd @@ -18,7 +18,6 @@ ctypedef vector[float_vec] float_matrix # Object used by the Entity Linker that summarizes one entity-alias candidate combination. cdef class Candidate: - cdef readonly KnowledgeBase kb cdef hash_t entity_hash cdef float entity_freq @@ -143,7 +142,6 @@ cdef class KnowledgeBase: cpdef load_bulk(self, loc) cpdef set_entities(self, entity_list, prob_list, vector_list) - cpdef set_aliases(self, alias_list, entities_list, probabilities_list) cdef class Writer: diff --git a/spacy/kb.pyx b/spacy/kb.pyx index 9a84439ea..72f66b107 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -1,23 +1,16 @@ # cython: infer_types=True # cython: profile=True # coding: utf8 -from collections import OrderedDict -from pathlib import Path, WindowsPath - -from cpython.exc cimport PyErr_CheckSignals - -from spacy import util from spacy.errors import Errors, Warnings, user_warning +from pathlib import Path from cymem.cymem cimport Pool from preshed.maps cimport PreshMap -from cpython.mem cimport PyMem_Malloc from cpython.exc cimport PyErr_SetFromErrno -from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek +from libc.stdio cimport fopen, fclose, fread, fwrite, feof, fseek from libc.stdint cimport int32_t, int64_t -from libc.stdlib cimport qsort from .typedefs cimport hash_t @@ -25,7 +18,6 @@ from os import path from libcpp.vector cimport vector - cdef class Candidate: def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob): @@ -79,8 +71,6 @@ cdef class KnowledgeBase: self._entry_index = PreshMap() self._alias_index = PreshMap() - # Should we initialize self._entries and self._aliases_table to specific starting size ? 
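        # --- Editor's note: illustrative sketch, not part of this commit ---
        # The bulk-population API kept by this patch can be exercised roughly as follows
        # (set_entities as declared in kb.pxd, add_alias as defined below; get_candidates
        # is assumed to be unchanged by this commit):
        #
        #   kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=64)
        #   kb.set_entities(entity_list=["Q42"], prob_list=[12.0], vector_list=[[0.0] * 64])
        #   kb.add_alias(alias="Douglas Adams", entities=["Q42"], probabilities=[0.9])
        #   candidates = kb.get_candidates("Douglas Adams")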
- self.vocab.strings.add("") self._create_empty_vectors(dummy_hash=self.vocab.strings[""]) @@ -165,47 +155,11 @@ cdef class KnowledgeBase: i += 1 - # TODO: this method is untested - cpdef set_aliases(self, alias_list, entities_list, probabilities_list): - nr_aliases = len(alias_list) - self._alias_index = PreshMap(nr_aliases+1) - self._aliases_table = alias_vec(nr_aliases+1) - - i = 0 - cdef AliasC alias - cdef int32_t dummy_value = 342 - while i <= nr_aliases: - alias_hash = self.vocab.strings.add(alias_list[i]) - entities = entities_list[i] - probabilities = probabilities_list[i] - - nr_candidates = len(entities) - entry_indices = vector[int64_t](nr_candidates) - probs = vector[float](nr_candidates) - - for j in range(0, nr_candidates): - entity = entities[j] - entity_hash = self.vocab.strings[entity] - if not entity_hash in self._entry_index: - raise ValueError(Errors.E134.format(alias=alias, entity=entity)) - - entry_index = self._entry_index.get(entity_hash) - entry_indices[j] = entry_index - - alias.entry_indices = entry_indices - alias.probs = probs - - self._aliases_table[i] = alias - self._alias_index[alias_hash] = i - - i += 1 - def add_alias(self, unicode alias, entities, probabilities): """ For a given alias, add its potential entities and prior probabilies to the KB. Return the alias_hash at the end """ - # Throw an error if the length of entities and probabilities are not the same if not len(entities) == len(probabilities): raise ValueError(Errors.E132.format(alias=alias, diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 99c361964..1c430a90b 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1068,8 +1068,6 @@ class EntityLinker(Pipe): DOCS: TODO """ name = 'entity_linker' - context_weight = 1 - prior_weight = 1 @classmethod def Model(cls, **cfg): @@ -1078,18 +1076,17 @@ class EntityLinker(Pipe): embed_width = cfg.get("embed_width", 300) hidden_width = cfg.get("hidden_width", 128) - - # no default because this needs to correspond with the KB entity length - entity_width = cfg.get("entity_width") + entity_width = cfg.get("entity_width") # this needs to correspond with the KB entity length model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg) - return model def __init__(self, **cfg): self.model = True self.kb = None self.cfg = dict(cfg) + self.context_weight = cfg.get("context_weight", 1) + self.prior_weight = cfg.get("prior_weight", 1) def set_kb(self, kb): self.kb = kb @@ -1162,7 +1159,6 @@ class EntityLinker(Pipe): if losses is not None: losses[self.name] += loss return loss - return 0 def get_loss(self, docs, golds, scores): @@ -1224,7 +1220,7 @@ class EntityLinker(Pipe): kb_id = c.entity_ entity_encoding = c.entity_vector sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight - score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ? + score = prior_prob + sim - (prior_prob*sim) scores.append(score) # TODO: thresholding
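        # --- Editor's note: illustrative, not part of this commit ---
        # The score above combines the alias prior and the context similarity as a
        # probabilistic OR (noisy-OR): score = p + s - p*s. For inputs in [0, 1] this
        # stays in [0, 1] and either signal alone can raise the score; e.g. with
        # prior_prob = 0.7 and sim = 0.5, score = 0.7 + 0.5 - 0.35 = 0.85. Both
        # context_weight and prior_weight are now read from the pipe's cfg in __init__
        # rather than being class-level attributes.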