introduce goldparse.links

This commit is contained in:
svlandeg 2019-06-07 13:54:45 +02:00
parent a5c061f506
commit 0486ccabfd
5 changed files with 82 additions and 53 deletions

View File

@ -303,8 +303,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
collect_correct=True,
collect_incorrect=True)
docs = list()
golds = list()
data = []
cnt = 0
next_entity_nr = 1
@ -323,7 +322,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
article_doc = nlp(text)
truncated_text = text[0:min(doc_cutoff, len(text))]
gold_entities = dict()
gold_entities = list()
# process all positive and negative entities, collect all relevant mentions in this article
for mention, entity_pos in correct_entries[article_id].items():
@ -337,11 +336,10 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
# store gold entities
for match_id, start, end in matches:
gold_entities[(start, end, entity_pos)] = 1.0
gold_entities.append((start, end, entity_pos))
gold = GoldParse(doc=article_doc, cats=gold_entities)
docs.append(article_doc)
golds.append(gold)
gold = GoldParse(doc=article_doc, links=gold_entities)
data.append((article_doc, gold))
cnt += 1
except Exception as e:
@ -352,7 +350,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
print()
print("Processed", cnt, "training articles, dev=" + str(dev))
print()
return docs, golds
return data

View File

@ -1,6 +1,10 @@
# coding: utf-8
from __future__ import unicode_literals
import random
from spacy.util import minibatch, compounding
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
from examples.pipeline.wiki_entity_linking.train_el import EL_Model
@ -23,9 +27,11 @@ VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES=10
MIN_PAIR_OCC=5
DOC_CHAR_CUTOFF=300
MAX_CANDIDATES = 10
MIN_PAIR_OCC = 5
DOC_CHAR_CUTOFF = 300
EPOCHS = 5
DROPOUT = 0.1
if __name__ == "__main__":
print("START", datetime.datetime.now())
@ -115,7 +121,7 @@ if __name__ == "__main__":
if train_pipe:
id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
docs, golds = training_set_creator.read_training(nlp=nlp,
train_data = training_set_creator.read_training(nlp=nlp,
training_dir=TRAINING_DIR,
id_to_descr=id_to_descr,
doc_cutoff=DOC_CHAR_CUTOFF,
@ -123,12 +129,6 @@ if __name__ == "__main__":
limit=10,
to_print=False)
# for doc, gold in zip(docs, golds):
# print("doc", doc)
# for entity, label in gold.cats.items():
# print("entity", entity, label)
# print()
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
nlp.add_pipe(el_pipe, last=True)
@ -136,6 +136,20 @@ if __name__ == "__main__":
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
nlp.begin_training()
for itn in range(EPOCHS):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
for batch in batches:
docs, golds = zip(*batch)
nlp.update(
docs,
golds,
drop=DROPOUT,
losses=losses,
)
print("Losses", losses)
### BELOW CODE IS DEPRECATED ###
# STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx

View File

@ -31,6 +31,7 @@ cdef class GoldParse:
cdef public list ents
cdef public dict brackets
cdef public object cats
cdef public list links
cdef readonly list cand_to_gold
cdef readonly list gold_to_cand

View File

@ -427,7 +427,7 @@ cdef class GoldParse:
def __init__(self, doc, annot_tuples=None, words=None, tags=None,
heads=None, deps=None, entities=None, make_projective=False,
cats=None, **_):
cats=None, links=None, **_):
"""Create a GoldParse.
doc (Doc): The document the annotations refer to.
@ -450,6 +450,8 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels
will be zero.
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
representing the external ID of an entity in a knowledge base.
RETURNS (GoldParse): The newly constructed object.
"""
if words is None:
@ -485,6 +487,7 @@ cdef class GoldParse:
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
self.cats = {} if cats is None else dict(cats)
self.links = links
self.words = [None] * len(doc)
self.tags = [None] * len(doc)
self.heads = [None] * len(doc)

View File

@ -1115,48 +1115,61 @@ class EntityLinker(Pipe):
self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
""" docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """
self.require_model()
if len(docs) != len(golds):
raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
n_golds=len(golds)))
entity_docs, article_docs, sentence_docs = docs
assert len(entity_docs) == len(article_docs) == len(sentence_docs)
if isinstance(docs, Doc):
docs = [docs]
golds = [golds]
if isinstance(entity_docs, Doc):
entity_docs = [entity_docs]
article_docs = [article_docs]
sentence_docs = [sentence_docs]
for doc, gold in zip(docs, golds):
print("doc", doc)
for entity in gold.links:
start, end, gold_kb = entity
print("entity", entity)
mention = doc[start:end].text
print("mention", mention)
candidates = self.kb.get_candidates(mention)
for c in candidates:
prior_prob = c.prior_prob
kb_id = c.entity_
print("candidate", kb_id, prior_prob)
entity_encoding = c.entity_vector
print()
entity_encodings = None #TODO
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
print()
concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
range(len(article_docs))]
mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
# gradient : concat (doc+sent) vs. desc
sent_start = self.article_encoder.nO
sent_gradients = list()
doc_gradients = list()
for x in mention_gradient:
doc_gradients.append(list(x[0:sent_start]))
sent_gradients.append(list(x[sent_start:]))
bp_doc(doc_gradients, sgd=self.sgd_article)
bp_sent(sent_gradients, sgd=self.sgd_sent)
if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += loss
return loss
# entity_encodings = None #TODO
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
#
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
# range(len(article_docs))]
# mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
#
# loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
#
# mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
#
# # gradient : concat (doc+sent) vs. desc
# sent_start = self.article_encoder.nO
# sent_gradients = list()
# doc_gradients = list()
# for x in mention_gradient:
# doc_gradients.append(list(x[0:sent_start]))
# sent_gradients.append(list(x[sent_start:]))
#
# bp_doc(doc_gradients, sgd=self.sgd_article)
# bp_sent(sent_gradients, sgd=self.sgd_sent)
#
# if losses is not None:
# losses.setdefault(self.name, 0.0)
# losses[self.name] += loss
# return loss
return None
def get_loss(self, docs, golds, scores):
loss, gradients = get_cossim_loss(scores, golds)