further code cleanup

svlandeg 2019-06-19 09:15:43 +02:00
parent 478305cd3f
commit a31648d28b
9 changed files with 76 additions and 166 deletions

View File

@@ -1,31 +1,31 @@
# coding: utf-8
from __future__ import unicode_literals
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from .train_descriptions import EntityEncoder
from . import wikidata_processor as wd, wikipedia_processor as wp
from spacy.kb import KnowledgeBase
import csv
import datetime
from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
INPUT_DIM = 300 # dimension of pre-trained vectors
DESC_WIDTH = 64
INPUT_DIM = 300 # dimension of pre-trained input vectors
DESC_WIDTH = 64 # dimension of output entity vectors
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
entity_def_output, entity_descr_output,
count_input, prior_prob_input, to_print=False):
count_input, prior_prob_input, wikidata_input):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
read_raw_data = False
read_raw_data = True
if read_raw_data:
print()
print(" * _read_wikidata_entities", datetime.datetime.now())
title_to_id, id_to_descr = wd.read_wikidata_entities_json(limit=None)
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
# write the title-ID and ID-description mappings to file
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
@@ -40,7 +40,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print()
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
# filter the entities to include in the KB by frequency, because there's just too much data otherwise
# filter the entities to include in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id = dict()
entity_list = list()
description_list = list()
@@ -60,11 +60,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print()
print(" * train entity encoder", datetime.datetime.now())
print()
encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH)
encoder.train(description_list=description_list, to_print=True)
print()
print()
print(" * get entity embeddings", datetime.datetime.now())
print()
embeddings = encoder.apply_encoder(description_list)
@@ -80,12 +79,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
prior_prob_input=prior_prob_input)
if to_print:
print()
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
print()
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
print("done with kb", datetime.datetime.now())
return kb
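For orientation, here is a rough sketch of how the refactored create_kb is now driven, with the Wikidata dump passed in explicitly instead of read from a hardcoded path. The paths below are placeholders and the values mirror the pipeline constants; this is not code from the commit.

# Sketch only: paths are hypothetical, values mirror MAX_CANDIDATES / MIN_ENTITY_FREQ / MIN_PAIR_OCC
from pathlib import Path
import spacy
from bin.wiki_entity_linking import kb_creator

output_dir = Path("/data/wikipedia")                        # hypothetical
wikidata_json = Path("/data/wikidata/latest-all.json.bz2")  # hypothetical

nlp = spacy.load('en_core_web_lg')
kb = kb_creator.create_kb(nlp,
                          max_entities_per_alias=10,
                          min_entity_freq=20,
                          min_occ=5,
                          entity_def_output=output_dir / 'entity_defs.csv',
                          entity_descr_output=output_dir / 'entity_descriptions.csv',
                          count_input=output_dir / 'entity_freq.csv',
                          prior_prob_input=output_dir / 'prior_prob.csv',
                          wikidata_input=wikidata_json)
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())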
@@ -94,6 +91,7 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
@@ -108,7 +106,6 @@ def get_entity_to_id(entity_def_output):
next(csvreader)
for row in csvreader:
entity_to_id[row[0]] = row[1]
return entity_to_id
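The entity definition and description files written above are plain pipe-delimited text; below is a minimal sketch of their layout and of reading one back. The sample rows and the '|' delimiter for the csv reader are assumptions based on the headers written in _write_entity_files, not content from the commit.

# Illustrative file contents:
#   entity_defs.csv               entity_descriptions.csv
#   WP_title|WD_id                WD_id|description
#   Douglas Adams|Q42             Q42|English writer and humorist
import csv

def read_mapping(path):
    # same pattern as get_entity_to_id above, assuming a '|' delimiter
    mapping = dict()
    with open(path, mode='r', encoding='utf8') as csvfile:
        csvreader = csv.reader(csvfile, delimiter='|')
        next(csvreader)  # skip the header row
        for row in csvreader:
            mapping[row[0]] = row[1]
    return mapping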
@@ -120,16 +117,12 @@ def _get_id_to_description(entity_descr_output):
next(csvreader)
for row in csvreader:
id_to_desc[row[0]] = row[1]
return id_to_desc
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input, to_print=False):
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
wp_titles = title_to_id.keys()
if to_print:
print("wp titles:", wp_titles)
# adding aliases with prior probabilities
# we can read this file sequentially; it's sorted by alias and then by count
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
@@ -176,6 +169,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
line = prior_file.readline()
if to_print:
print("added", kb.get_size_aliases(), "aliases:", kb.get_alias_strings())

View File

@@ -32,8 +32,6 @@ class EntityEncoder:
if self.encoder is None:
raise ValueError("Can not apply encoder before training it")
print("Encoding", len(description_list), "entities")
batch_size = 100000
start = 0
@@ -48,13 +46,11 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
print("encoded :", len(encodings))
return encodings
def train(self, description_list, to_print=False):
processed, loss = self._train_model(description_list)
if to_print:
print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
print("Final loss:", loss)
@@ -111,15 +107,12 @@ class EntityEncoder:
Affine(hidden_with, orig_width)
)
self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0))
self.sgd = create_default_optimizer(self.model.ops)
def _update(self, vectors):
predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP)
loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors))
bp_model(d_scores, sgd=self.sgd)
return loss / len(vectors)
@staticmethod
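The _update step above relies on _get_loss, which falls outside this hunk. As a rough sketch only, assuming a plain mean-squared reconstruction loss between the decoded vectors and the original input vectors (the commit itself may use a different loss):

import numpy as np

def _get_loss(scores, golds):
    # assumed form: L2 reconstruction loss plus its gradient w.r.t. the predictions
    d_scores = (scores - golds) / scores.shape[0]
    loss = np.sum((scores - golds) ** 2) / scores.shape[0]
    return loss, d_scores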

View File

@@ -18,23 +18,21 @@ Gold-standard entities are stored in one file in standoff format (by character o
ENTITY_FILE = "gold_entities_1000000.csv" # use this file for faster processing
def create_training(entity_def_input, training_output):
def create_training(wikipedia_input, entity_def_input, training_output):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wp_to_id, training_output, limit=None)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
read_ids = set()
entityfile_loc = training_output + "/" + ENTITY_FILE
entityfile_loc = training_output / ENTITY_FILE
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
# write entity training header file
_write_training_entity(outputfile=entityfile,
@@ -44,7 +42,7 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
start="start",
end="end")
with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
with bz2.open(wikipedia_input, mode='rb') as file:
line = file.readline()
cnt = 0
article_text = ""
@@ -104,7 +102,7 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
print("Found duplicate article ID", article_id, clean_line) # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
# read the title of this article (outside the revision portion of the document)
if not reading_revision:
titles = title_regex.search(clean_line)
if titles:
@@ -134,7 +132,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
# get the raw text without markup etc, keeping only interwiki links
clean_text = _get_clean_wp_text(text)
# read the text char by char to get the right offsets of the interwiki links
# read the text char by char to get the right offsets for the interwiki links
final_text = ""
open_read = 0
reading_text = True
@@ -274,7 +272,7 @@ def _get_clean_wp_text(article_text):
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output + "/" + str(article_id) + ".txt"
file_loc = training_output / (str(article_id) + ".txt")
with open(file_loc, mode='w', encoding='utf8') as outputfile:
outputfile.write(clean_text)
@@ -289,11 +287,10 @@ def is_dev(article_id):
def read_training(nlp, training_dir, dev, limit):
# This method provides training examples that correspond to the entity annotations found by the nlp object
entityfile_loc = training_dir + "/" + ENTITY_FILE
entityfile_loc = training_dir / ENTITY_FILE
data = []
# we assume the data is written sequentially
# assume the data is written sequentially, so we can reuse the article docs
current_article_id = None
current_doc = None
ents_by_offset = dict()
@@ -347,10 +344,10 @@ def read_training(nlp, training_dir, dev, limit):
gold_end = int(end) - found_ent.sent.start_char
gold_entities = list()
gold_entities.append((gold_start, gold_end, wp_title))
gold = GoldParse(doc=current_doc, links=gold_entities)
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1
if len(data) % 500 == 0:
if len(data) % 2500 == 0:
print(" -read", total_entities, "entities")
print(" -read", total_entities, "entities")

View File

@@ -5,17 +5,15 @@ import bz2
import json
import datetime
# TODO: remove hardcoded paths
WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.json.bz2'
def read_wikidata_entities_json(limit=None, to_print=False):
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
# Read the JSON wiki data and parse out the entities. Takes about 7h30m to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
lang = 'en'
site_filter = 'enwiki'
# filter currently disabled to get ALL data
# properties filter (currently disabled to get ALL data)
prop_filter = dict()
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
@@ -30,7 +28,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
parse_aliases = False
parse_claims = False
with bz2.open(WIKIDATA_JSON, mode='rb') as file:
with bz2.open(wikidata_file, mode='rb') as file:
line = file.readline()
cnt = 0
while line and (not limit or cnt < limit):
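Each line of the dump opened above is one JSON entity record. A rough sketch of pulling out the fields this reader cares about, following the public Wikidata dump layout; the helper below is an illustration, not code from the commit:

import json

def parse_entity_line(line, lang='en', site_filter='enwiki'):
    clean = line.strip().rstrip(b",")      # one JSON object per line, comma-terminated
    if clean in (b"[", b"]", b""):
        return None                        # enclosing array brackets / empty lines
    obj = json.loads(clean)
    qid = obj.get("id")                                                  # e.g. "Q42"
    title = obj.get("sitelinks", {}).get(site_filter, {}).get("title")   # enwiki article title
    descr = obj.get("descriptions", {}).get(lang, {}).get("value")       # English description
    return qid, title, descr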

View File

@@ -11,11 +11,6 @@ Process a Wikipedia dump to calculate entity frequencies and prior probabilities
Write these results to file for downstream KB and training data generation.
"""
# TODO: remove hardcoded paths
ENWIKI_DUMP = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream.xml.bz2'
ENWIKI_INDEX = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream-index.txt.bz2'
map_alias_to_link = dict()
# these will/should be matched ignoring case
@@ -46,15 +41,13 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
def read_wikipedia_prior_probs(prior_prob_output):
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities
The full file takes about 2h to parse 1100M lines (update printed every 5M lines).
It works relatively fast because we don't care about which article we parsed the interwiki from,
we just process line by line.
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines.
It works relatively fast because it runs line by line, regardless of which article the intrawiki link is from.
"""
with bz2.open(ENWIKI_DUMP, mode='rb') as file:
with bz2.open(wikipedia_input, mode='rb') as file:
line = file.readline()
cnt = 0
while line:
@@ -70,7 +63,7 @@ def read_wikipedia_prior_probs(prior_prob_output):
line = file.readline()
cnt += 1
# write all aliases and their entities and occurrences to file
# write all aliases with their entities and occurrence counts to file
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
@@ -108,7 +101,7 @@ def get_wp_links(text):
if ns_regex.match(match):
pass # ignore namespaces at the beginning of the string
# this is a simple link, with the alias the same as the mention
# this is a simple [[link]], with the alias the same as the mention
elif "|" not in match:
aliases.append(match)
entities.append(match)
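Downstream, the alias|count|entity rows written by this module are normalized into prior probabilities per alias. A small sketch of that normalization (helper name and numbers are illustrative, not from the commit):

def prior_probabilities(alias_counts):
    # alias_counts maps entity -> occurrence count for one alias,
    # e.g. {"Paris": 95, "Paris, Texas": 5}
    total = sum(alias_counts.values())
    return {entity: count / total for entity, count in alias_counts.items()}

# prior_probabilities({"Paris": 95, "Paris, Texas": 5}) -> {"Paris": 0.95, "Paris, Texas": 0.05}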

View File

@@ -2,35 +2,45 @@
from __future__ import unicode_literals
import random
from spacy.util import minibatch, compounding
import datetime
from pathlib import Path
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy
from spacy.kb import KnowledgeBase
import datetime
from spacy.util import minibatch, compounding
"""
Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm.
"""
PRIOR_PROB = 'C:/Users/Sofie/Documents/data/wikipedia/prior_prob.csv'
ENTITY_COUNTS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_freq.csv'
ENTITY_DEFS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_defs.csv'
ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv'
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
OUTPUT_DIR = ROOT_DIR / 'wikipedia'
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb_1/kb'
NLP_1_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_1'
NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
# KB construction parameters
MAX_CANDIDATES = 10
MIN_ENTITY_FREQ = 20
MIN_PAIR_OCC = 5
# model training parameters
EPOCHS = 10
DROPOUT = 0.1
LEARN_RATE = 0.005
@@ -38,6 +48,7 @@ L2 = 1e-6
def run_pipeline():
# set the appropriate booleans to define which parts of the pipeline should be (re)run
print("START", datetime.datetime.now())
print()
nlp_1 = spacy.load('en_core_web_lg')
@@ -67,22 +78,19 @@ def run_pipeline():
to_write_nlp = False
to_read_nlp = False
# STEP 1 : create prior probabilities from WP
# run only once !
# STEP 1 : create prior probabilities from WP (run only once)
if to_create_prior_probs:
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
wp.read_wikipedia_prior_probs(prior_prob_output=PRIOR_PROB)
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
print()
# STEP 2 : deduce entity frequencies from WP
# run only once !
# STEP 2 : deduce entity frequencies from WP (run only once)
if to_create_entity_counts:
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
print()
# STEP 3 : create KB and write to file
# run only once !
# STEP 3 : create KB and write to file (run only once)
if to_create_kb:
print("STEP 3a: to_create_kb", datetime.datetime.now())
kb_1 = kb_creator.create_kb(nlp_1,
@@ -93,7 +101,7 @@ def run_pipeline():
entity_descr_output=ENTITY_DESCR,
count_input=ENTITY_COUNTS,
prior_prob_input=PRIOR_PROB,
to_print=False)
wikidata_input=WIKIDATA_JSON)
print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases())
print()
@@ -121,7 +129,9 @@ def run_pipeline():
# STEP 5: create a training dataset from WP
if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now())
training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
entity_def_input=ENTITY_DEFS,
training_output=TRAINING_DIR)
# STEP 6: create and train the entity linking pipe
el_pipe = nlp_2.create_pipe(name='entity_linker', config={})
@@ -136,7 +146,8 @@ def run_pipeline():
if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
train_limit = 25000
# define the size (nr of entities) of the training and dev sets
train_limit = 10000
dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2,
@@ -157,7 +168,6 @@ def run_pipeline():
if not train_data:
print("Did not find any training data")
else:
for itn in range(EPOCHS):
random.shuffle(train_data)
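The body of the epoch loop is outside this hunk; a typical spaCy 2.x update loop using the minibatch/compounding imports and the DROPOUT constant declared above would look roughly like the sketch below. Batch sizes are illustrative and the exact loop in the commit may differ.

losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
for batch in batches:
    docs, golds = zip(*batch)
    nlp_2.update(docs, golds, drop=DROPOUT, losses=losses)
print(itn, "losses:", losses)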
@@ -196,7 +206,7 @@ def run_pipeline():
print()
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
print("dev counts:", sorted(counts))
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
@@ -215,7 +225,6 @@ def run_pipeline():
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
print("dev acc context avg:", round(dev_acc_context, 3),
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
print()
# reset for follow-up tests
el_pipe.context_weight = 1
@@ -227,7 +236,6 @@ def run_pipeline():
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
print()
run_el_toy_example(nlp=nlp_2)
print()
# STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp:
@@ -400,26 +408,9 @@ def run_el_toy_example(nlp):
doc = nlp(text)
print(text)
for ent in doc.ents:
print("ent", ent.text, ent.label_, ent.kb_id_)
print(" ent", ent.text, ent.label_, ent.kb_id_)
print()
# Q4426480 is her husband
text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
"She loved her husband William King dearly. "
doc = nlp(text)
print(text)
for ent in doc.ents:
print("ent", ent.text, ent.label_, ent.kb_id_)
print()
# Q3568763 is her tutor
text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
"She was tutored by her favorite physics tutor William King."
doc = nlp(text)
print(text)
for ent in doc.ents:
print("ent", ent.text, ent.label_, ent.kb_id_)
if __name__ == "__main__":
run_pipeline()

View File

@@ -18,7 +18,6 @@ ctypedef vector[float_vec] float_matrix
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
cdef class Candidate:
cdef readonly KnowledgeBase kb
cdef hash_t entity_hash
cdef float entity_freq
@@ -143,7 +142,6 @@ cdef class KnowledgeBase:
cpdef load_bulk(self, loc)
cpdef set_entities(self, entity_list, prob_list, vector_list)
cpdef set_aliases(self, alias_list, entities_list, probabilities_list)
cdef class Writer:

View File

@@ -1,23 +1,16 @@
# cython: infer_types=True
# cython: profile=True
# coding: utf8
from collections import OrderedDict
from pathlib import Path, WindowsPath
from cpython.exc cimport PyErr_CheckSignals
from spacy import util
from spacy.errors import Errors, Warnings, user_warning
from pathlib import Path
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
from cpython.mem cimport PyMem_Malloc
from cpython.exc cimport PyErr_SetFromErrno
from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
from libc.stdio cimport fopen, fclose, fread, fwrite, feof, fseek
from libc.stdint cimport int32_t, int64_t
from libc.stdlib cimport qsort
from .typedefs cimport hash_t
@@ -25,7 +18,6 @@ from os import path
from libcpp.vector cimport vector
cdef class Candidate:
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob):
@@ -79,8 +71,6 @@ cdef class KnowledgeBase:
self._entry_index = PreshMap()
self._alias_index = PreshMap()
# Should we initialize self._entries and self._aliases_table to a specific starting size?
self.vocab.strings.add("")
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
@@ -165,47 +155,11 @@ cdef class KnowledgeBase:
i += 1
# TODO: this method is untested
cpdef set_aliases(self, alias_list, entities_list, probabilities_list):
nr_aliases = len(alias_list)
self._alias_index = PreshMap(nr_aliases+1)
self._aliases_table = alias_vec(nr_aliases+1)
i = 0
cdef AliasC alias
cdef int32_t dummy_value = 342
while i <= nr_aliases:
alias_hash = self.vocab.strings.add(alias_list[i])
entities = entities_list[i]
probabilities = probabilities_list[i]
nr_candidates = len(entities)
entry_indices = vector[int64_t](nr_candidates)
probs = vector[float](nr_candidates)
for j in range(0, nr_candidates):
entity = entities[j]
entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index:
raise ValueError(Errors.E134.format(alias=alias, entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash)
entry_indices[j] = entry_index
alias.entry_indices = entry_indices
alias.probs = probs
self._aliases_table[i] = alias
self._alias_index[alias_hash] = i
i += 1
def add_alias(self, unicode alias, entities, probabilities):
"""
For a given alias, add its potential entities and prior probabilities to the KB.
Return the alias_hash at the end
"""
# Throw an error if the length of entities and probabilities are not the same
if not len(entities) == len(probabilities):
raise ValueError(Errors.E132.format(alias=alias,
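For reference, a minimal usage sketch of the KnowledgeBase API declared in kb.pxd above; the values are toy examples, and the keyword names follow the declarations shown here and may differ between versions.

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.load('en_core_web_lg')
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
# vectors must match entity_vector_length; probabilities are priors per alias
kb.set_entities(entity_list=["Q42"], prob_list=[0.8], vector_list=[[1.0, 2.0, 3.0]])
kb.add_alias(alias="Adams", entities=["Q42"], probabilities=[0.9])
print(kb.get_size_entities(), kb.get_size_aliases())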

View File

@@ -1068,8 +1068,6 @@ class EntityLinker(Pipe):
DOCS: TODO
"""
name = 'entity_linker'
context_weight = 1
prior_weight = 1
@classmethod
def Model(cls, **cfg):
@@ -1078,18 +1076,17 @@ class EntityLinker(Pipe):
embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 128)
# no default because this needs to correspond with the KB entity length
entity_width = cfg.get("entity_width")
entity_width = cfg.get("entity_width") # this needs to correspond with the KB entity length
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg)
return model
def __init__(self, **cfg):
self.model = True
self.kb = None
self.cfg = dict(cfg)
self.context_weight = cfg.get("context_weight", 1)
self.prior_weight = cfg.get("prior_weight", 1)
def set_kb(self, kb):
self.kb = kb
@@ -1162,7 +1159,6 @@ class EntityLinker(Pipe):
if losses is not None:
losses[self.name] += loss
return loss
return 0
def get_loss(self, docs, golds, scores):
@@ -1224,7 +1220,7 @@ class EntityLinker(Pipe):
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
score = prior_prob + sim - (prior_prob*sim)
scores.append(score)
# TODO: thresholding
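For intuition, the score combination kept above is the probabilistic OR of the prior probability and the context similarity, so either signal can only raise the combined score:

prior_prob, sim = 0.6, 0.5
score = prior_prob + sim - (prior_prob * sim)   # 0.6 + 0.5 - 0.3 = 0.8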