From 7de1ee69b819cba8b66db370dcb1ec169b4a7b74 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Fri, 7 Jun 2019 15:55:10 +0200 Subject: [PATCH] training loop in proper pipe format --- .../wiki_entity_linking/wiki_nel_pipeline.py | 13 +-- spacy/pipeline/pipes.pyx | 84 ++++++++++--------- 2 files changed, 49 insertions(+), 48 deletions(-) diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py index b66f8b316..ded4bdc24 100644 --- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py +++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py @@ -126,7 +126,7 @@ if __name__ == "__main__": id_to_descr=id_to_descr, doc_cutoff=DOC_CHAR_CUTOFF, dev=False, - limit=10, + limit=100, to_print=False) el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb}) @@ -137,6 +137,8 @@ if __name__ == "__main__": nlp.begin_training() for itn in range(EPOCHS): + print() + print("EPOCH", itn) random.shuffle(train_data) losses = {} batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001)) @@ -150,15 +152,6 @@ if __name__ == "__main__": ) print("Losses", losses) - ### BELOW CODE IS DEPRECATED ### - - # STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx - if run_el_training: - print("STEP 6: training", datetime.datetime.now()) - trainer = EL_Model(kb=my_kb, nlp=nlp) - trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500) - print() - # STEP 7: apply the EL algorithm on the dev dataset (TODO: overlaps with code from run_el_training ?) if apply_to_dev: run_el.run_el_dev(kb=my_kb, nlp=nlp, training_dir=TRAINING_DIR, limit=2000) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index f15ffd036..01302b618 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1125,51 +1125,59 @@ class EntityLinker(Pipe): docs = [docs] golds = [golds] + article_docs = list() + sentence_docs = list() + entity_encodings = list() + for doc, gold in zip(docs, golds): - print("doc", doc) for entity in gold.links: start, end, gold_kb = entity - print("entity", entity) - mention = doc[start:end].text - print("mention", mention) - candidates = self.kb.get_candidates(mention) + mention = doc[start:end] + sentence = mention.sent + + candidates = self.kb.get_candidates(mention.text) for c in candidates: - prior_prob = c.prior_prob kb_id = c.entity_ - print("candidate", kb_id, prior_prob) - entity_encoding = c.entity_vector - print() + # TODO: currently only training on the positive instances + if kb_id == gold_kb: + prior_prob = c.prior_prob + entity_encoding = c.entity_vector - print() + entity_encodings.append(entity_encoding) + article_docs.append(doc) + sentence_docs.append(sentence.as_doc()) - # entity_encodings = None #TODO - # doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop) - # sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop) - # - # concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in - # range(len(article_docs))] - # mention_encodings, bp_cont = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP) - # - # loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None) - # - # mention_gradient = bp_cont(d_scores, sgd=self.sgd_cont) - # - # # gradient : concat (doc+sent) vs. desc - # sent_start = self.article_encoder.nO - # sent_gradients = list() - # doc_gradients = list() - # for x in mention_gradient: - # doc_gradients.append(list(x[0:sent_start])) - # sent_gradients.append(list(x[sent_start:])) - # - # bp_doc(doc_gradients, sgd=self.sgd_article) - # bp_sent(sent_gradients, sgd=self.sgd_sent) - # - # if losses is not None: - # losses.setdefault(self.name, 0.0) - # losses[self.name] += loss - # return loss - return None + if len(entity_encodings) > 0: + doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop) + sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop) + + concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in + range(len(article_docs))] + mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop) + + entity_encodings = np.asarray(entity_encodings, dtype=np.float32) + + loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None) + + mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention) + + # gradient : concat (doc+sent) vs. desc + sent_start = self.article_encoder.nO + sent_gradients = list() + doc_gradients = list() + for x in mention_gradient: + doc_gradients.append(list(x[0:sent_start])) + sent_gradients.append(list(x[sent_start:])) + + bp_doc(doc_gradients, sgd=self.sgd_article) + bp_sent(sent_gradients, sgd=self.sgd_sent) + + if losses is not None: + losses.setdefault(self.name, 0.0) + losses[self.name] += loss + return loss + + return 0 def get_loss(self, docs, golds, scores): loss, gradients = get_cossim_loss(scores, golds)