mirror of https://github.com/explosion/spaCy.git
introduce goldparse.links
This commit is contained in:
parent
a5c061f506
commit
0486ccabfd
|
@ -303,8 +303,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
|||
collect_correct=True,
|
||||
collect_incorrect=True)
|
||||
|
||||
docs = list()
|
||||
golds = list()
|
||||
data = []
|
||||
|
||||
cnt = 0
|
||||
next_entity_nr = 1
|
||||
|
@ -323,7 +322,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
|||
article_doc = nlp(text)
|
||||
truncated_text = text[0:min(doc_cutoff, len(text))]
|
||||
|
||||
gold_entities = dict()
|
||||
gold_entities = list()
|
||||
|
||||
# process all positive and negative entities, collect all relevant mentions in this article
|
||||
for mention, entity_pos in correct_entries[article_id].items():
|
||||
|
@ -337,11 +336,10 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
|||
|
||||
# store gold entities
|
||||
for match_id, start, end in matches:
|
||||
gold_entities[(start, end, entity_pos)] = 1.0
|
||||
gold_entities.append((start, end, entity_pos))
|
||||
|
||||
gold = GoldParse(doc=article_doc, cats=gold_entities)
|
||||
docs.append(article_doc)
|
||||
golds.append(gold)
|
||||
gold = GoldParse(doc=article_doc, links=gold_entities)
|
||||
data.append((article_doc, gold))
|
||||
|
||||
cnt += 1
|
||||
except Exception as e:
|
||||
|
@ -352,7 +350,7 @@ def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_pri
|
|||
print()
|
||||
print("Processed", cnt, "training articles, dev=" + str(dev))
|
||||
print()
|
||||
return docs, golds
|
||||
return data
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import random
|
||||
|
||||
from spacy.util import minibatch, compounding
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
|
||||
from examples.pipeline.wiki_entity_linking.train_el import EL_Model
|
||||
|
||||
|
@ -23,9 +27,11 @@ VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
|
|||
|
||||
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
||||
|
||||
MAX_CANDIDATES=10
|
||||
MIN_PAIR_OCC=5
|
||||
DOC_CHAR_CUTOFF=300
|
||||
MAX_CANDIDATES = 10
|
||||
MIN_PAIR_OCC = 5
|
||||
DOC_CHAR_CUTOFF = 300
|
||||
EPOCHS = 5
|
||||
DROPOUT = 0.1
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("START", datetime.datetime.now())
|
||||
|
@ -115,7 +121,7 @@ if __name__ == "__main__":
|
|||
if train_pipe:
|
||||
id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
|
||||
|
||||
docs, golds = training_set_creator.read_training(nlp=nlp,
|
||||
train_data = training_set_creator.read_training(nlp=nlp,
|
||||
training_dir=TRAINING_DIR,
|
||||
id_to_descr=id_to_descr,
|
||||
doc_cutoff=DOC_CHAR_CUTOFF,
|
||||
|
@ -123,12 +129,6 @@ if __name__ == "__main__":
|
|||
limit=10,
|
||||
to_print=False)
|
||||
|
||||
# for doc, gold in zip(docs, golds):
|
||||
# print("doc", doc)
|
||||
# for entity, label in gold.cats.items():
|
||||
# print("entity", entity, label)
|
||||
# print()
|
||||
|
||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
|
||||
nlp.add_pipe(el_pipe, last=True)
|
||||
|
||||
|
@ -136,6 +136,20 @@ if __name__ == "__main__":
|
|||
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
|
||||
nlp.begin_training()
|
||||
|
||||
for itn in range(EPOCHS):
|
||||
random.shuffle(train_data)
|
||||
losses = {}
|
||||
batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
|
||||
for batch in batches:
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(
|
||||
docs,
|
||||
golds,
|
||||
drop=DROPOUT,
|
||||
losses=losses,
|
||||
)
|
||||
print("Losses", losses)
|
||||
|
||||
### BELOW CODE IS DEPRECATED ###
|
||||
|
||||
# STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
|
||||
|
|
|
@ -31,6 +31,7 @@ cdef class GoldParse:
|
|||
cdef public list ents
|
||||
cdef public dict brackets
|
||||
cdef public object cats
|
||||
cdef public list links
|
||||
|
||||
cdef readonly list cand_to_gold
|
||||
cdef readonly list gold_to_cand
|
||||
|
|
|
@ -427,7 +427,7 @@ cdef class GoldParse:
|
|||
|
||||
def __init__(self, doc, annot_tuples=None, words=None, tags=None,
|
||||
heads=None, deps=None, entities=None, make_projective=False,
|
||||
cats=None, **_):
|
||||
cats=None, links=None, **_):
|
||||
"""Create a GoldParse.
|
||||
|
||||
doc (Doc): The document the annotations refer to.
|
||||
|
@ -450,6 +450,8 @@ cdef class GoldParse:
|
|||
examples of a label to have the value 0.0. Labels not in the
|
||||
dictionary are treated as missing - the gradient for those labels
|
||||
will be zero.
|
||||
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
|
||||
representing the external ID of an entity in a knowledge base.
|
||||
RETURNS (GoldParse): The newly constructed object.
|
||||
"""
|
||||
if words is None:
|
||||
|
@ -485,6 +487,7 @@ cdef class GoldParse:
|
|||
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
||||
|
||||
self.cats = {} if cats is None else dict(cats)
|
||||
self.links = links
|
||||
self.words = [None] * len(doc)
|
||||
self.tags = [None] * len(doc)
|
||||
self.heads = [None] * len(doc)
|
||||
|
|
|
@ -1115,48 +1115,61 @@ class EntityLinker(Pipe):
|
|||
self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
|
||||
|
||||
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
|
||||
""" docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """
|
||||
self.require_model()
|
||||
|
||||
if len(docs) != len(golds):
|
||||
raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
|
||||
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
|
||||
n_golds=len(golds)))
|
||||
|
||||
entity_docs, article_docs, sentence_docs = docs
|
||||
assert len(entity_docs) == len(article_docs) == len(sentence_docs)
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
golds = [golds]
|
||||
|
||||
if isinstance(entity_docs, Doc):
|
||||
entity_docs = [entity_docs]
|
||||
article_docs = [article_docs]
|
||||
sentence_docs = [sentence_docs]
|
||||
for doc, gold in zip(docs, golds):
|
||||
print("doc", doc)
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
print("entity", entity)
|
||||
mention = doc[start:end].text
|
||||
print("mention", mention)
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
for c in candidates:
|
||||
prior_prob = c.prior_prob
|
||||
kb_id = c.entity_
|
||||
print("candidate", kb_id, prior_prob)
|
||||
entity_encoding = c.entity_vector
|
||||
print()
|
||||
|
||||
entity_encodings = None #TODO
|
||||
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
||||
sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
|
||||
print()
|
||||
|
||||
concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
|
||||
range(len(article_docs))]
|
||||
mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||
|
||||
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
|
||||
|
||||
mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
|
||||
|
||||
# gradient : concat (doc+sent) vs. desc
|
||||
sent_start = self.article_encoder.nO
|
||||
sent_gradients = list()
|
||||
doc_gradients = list()
|
||||
for x in mention_gradient:
|
||||
doc_gradients.append(list(x[0:sent_start]))
|
||||
sent_gradients.append(list(x[sent_start:]))
|
||||
|
||||
bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||
bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += loss
|
||||
return loss
|
||||
# entity_encodings = None #TODO
|
||||
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
||||
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
|
||||
#
|
||||
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
|
||||
# range(len(article_docs))]
|
||||
# mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
|
||||
#
|
||||
# loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
|
||||
#
|
||||
# mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont)
|
||||
#
|
||||
# # gradient : concat (doc+sent) vs. desc
|
||||
# sent_start = self.article_encoder.nO
|
||||
# sent_gradients = list()
|
||||
# doc_gradients = list()
|
||||
# for x in mention_gradient:
|
||||
# doc_gradients.append(list(x[0:sent_start]))
|
||||
# sent_gradients.append(list(x[sent_start:]))
|
||||
#
|
||||
# bp_doc(doc_gradients, sgd=self.sgd_article)
|
||||
# bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
#
|
||||
# if losses is not None:
|
||||
# losses.setdefault(self.name, 0.0)
|
||||
# losses[self.name] += loss
|
||||
# return loss
|
||||
return None
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
loss, gradients = get_cossim_loss(scores, golds)
|
||||
|
|
Loading…
Reference in New Issue