60% acc run

svlandeg 2019-06-03 08:04:49 +02:00
parent 268a52ead7
commit 9e88763dab
2 changed files with 74 additions and 88 deletions

View File

@@ -23,7 +23,6 @@ from thinc.misc import LayerNorm as LN
# from spacy.cli.pretrain import get_cossim_loss
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc
""" TODO: this code needs to be implemented in pipes.pyx"""
@@ -46,7 +45,7 @@ class EL_Model:
DROP = 0.1
LEARN_RATE = 0.001
EPOCHS = 10
EPOCHS = 20
L2 = 1e-6
name = "entity_linker"
@@ -211,9 +210,6 @@ class EL_Model:
return acc
def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
# print()
# print("predicting article")
if avg:
with self.article_encoder.use_params(self.sgd_article.averages) \
and self.desc_encoder.use_params(self.sgd_desc.averages)\
@@ -228,16 +224,10 @@ class EL_Model:
doc_encoding = self.article_encoder([article_doc])
sent_encoding = self.sent_encoder([sent_doc])
# print("desc_encodings", desc_encodings)
# print("doc_encoding", doc_encoding)
# print("sent_encoding", sent_encoding)
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
# print("concat_encoding", concat_encoding)
cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
# print("cont_encodings", cont_encodings)
context_enc = np.transpose(cont_encodings)
# print("context_enc", context_enc)
highest_sim = -5
best_i = -1
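For reference, a minimal runnable NumPy sketch of the candidate ranking that _predict builds up to here: the single context encoding is compared against every candidate description encoding by cosine similarity, and the best-scoring candidate is kept. Toy vectors stand in for the thinc encoder outputs; this is an illustration, not the committed code.

import numpy as np

def cosine(a, b):
    # cosine similarity with a small epsilon to avoid division by zero
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))

context_enc = np.random.rand(64)        # stand-in for the context encoder output
desc_encodings = np.random.rand(3, 64)  # stand-in for the candidate description encodings

highest_sim = -5
best_i = -1
for i, desc_enc in enumerate(desc_encodings):
    sim = cosine(context_enc, desc_enc)
    if sim > highest_sim:
        highest_sim = sim
        best_i = i

print(best_i, highest_sim)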
@@ -353,11 +343,11 @@ class EL_Model:
sents_list.append(sent)
descs_list.append(descs[e])
targets.append([1])
else:
arts_list.append(art)
sents_list.append(sent)
descs_list.append(descs[e])
targets.append([-1])
# else:
# arts_list.append(art)
# sents_list.append(sent)
# descs_list.append(descs[e])
# targets.append([-1])
desc_docs = self.nlp.pipe(descs_list)
desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
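The batch layout that remains after this change is easiest to see in a toy, self-contained sketch (made-up data): each mention now contributes only its gold description, paired with a +1 target, since the negative-description branch is commented out in this run.

# made-up clusters: (truncated article text, sentence containing the mention, gold description)
clusters = [
    ("article one ...", "Sentence mentioning Paris.", "capital city of France"),
    ("article two ...", "Sentence mentioning Python.", "high-level programming language"),
]

arts_list, sents_list, descs_list, targets = [], [], [], []
for art, sent, gold_desc in clusters:
    arts_list.append(art)
    sents_list.append(sent)
    descs_list.append(gold_desc)
    targets.append([1])  # positive pair only: context and gold description should align

print(len(descs_list), targets)  # 2 [[1], [1]]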
@@ -372,18 +362,17 @@ class EL_Model:
range(len(targets))]
cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
# print("sent_encodings", type(sent_encodings), sent_encodings)
# print("desc_encodings", type(desc_encodings), desc_encodings)
# print("doc_encodings", type(doc_encodings), doc_encodings)
# print("getting los for", len(arts_list), "entities")
loss, cont_gradient = self.get_loss(cont_encodings, desc_encodings, targets)
loss, gradient = self.get_loss(cont_encodings, desc_encodings, targets)
# loss, desc_gradient = self.get_loss(desc_encodings, cont_encodings, targets)
# cont_gradient = cont_gradient / 2
# desc_gradient = desc_gradient / 2
# bp_desc(desc_gradient, sgd=self.sgd_desc)
# print("gradient", gradient)
if self.PRINT_BATCH_LOSS:
print("batch loss", loss)
context_gradient = bp_cont(gradient, sgd=self.sgd_cont)
context_gradient = bp_cont(cont_gradient, sgd=self.sgd_cont)
# gradient : concat (doc+sent) vs. desc
sent_start = self.ARTICLE_WIDTH
@@ -393,9 +382,6 @@ class EL_Model:
doc_gradients.append(list(x[0:sent_start]))
sent_gradients.append(list(x[sent_start:]))
# print("doc_gradients", doc_gradients)
# print("sent_gradients", sent_gradients)
bp_doc(doc_gradients, sgd=self.sgd_article)
bp_sent(sent_gradients, sgd=self.sgd_sent)
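For orientation, a hedged NumPy sketch of the two ideas in this update step: an illustrative cosine-similarity loss against +1/-1 targets (not the model's actual get_loss), and the split of the gradient on the concatenated article + sentence encoding back into its two halves at ARTICLE_WIDTH. Widths and data are toy values.

import numpy as np

ARTICLE_WIDTH, SENT_WIDTH = 4, 3

def cossim_loss(pred, gold, targets):
    # pushes each row's cosine similarity towards its +1/-1 target; returns loss and d_loss/d_pred
    pred = pred + 1e-8
    gold = gold + 1e-8
    norm_p = np.linalg.norm(pred, axis=1, keepdims=True)
    norm_g = np.linalg.norm(gold, axis=1, keepdims=True)
    cosine = (pred * gold).sum(axis=1, keepdims=True) / (norm_p * norm_g)
    targets = np.asarray(targets, dtype="float64")  # shape (batch, 1), values +1 or -1
    loss = np.abs(cosine - targets).sum()
    d_cos_d_pred = gold / (norm_p * norm_g) - cosine * pred / norm_p ** 2
    d_pred = np.sign(cosine - targets) * d_cos_d_pred
    return loss, d_pred

# toy context and description encodings, one positive and one negative pair
cont_encodings = np.random.rand(2, 5)
desc_encodings = np.random.rand(2, 5)
loss, cont_gradient = cossim_loss(cont_encodings, desc_encodings, [[1], [-1]])
print(round(float(loss), 3), cont_gradient.shape)

# the gradient w.r.t. the concatenated (article + sentence) encoding is then split
# back into its two halves at ARTICLE_WIDTH, mirroring the loop above
concat_gradient = np.random.rand(2, ARTICLE_WIDTH + SENT_WIDTH)
sent_start = ARTICLE_WIDTH
doc_gradients = [list(x[0:sent_start]) for x in concat_gradient]
sent_gradients = [list(x[sent_start:]) for x in concat_gradient]
print(len(doc_gradients[0]), len(sent_gradients[0]))  # 4 3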
@@ -426,74 +412,75 @@ class EL_Model:
article_id = f.replace(".txt", "")
if cnt % 500 == 0 and to_print:
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
cnt += 1
# parse the article text
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
text = file.read()
article_doc = self.nlp(text)
truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
text_by_article[article_id] = truncated_text
try:
# parse the article text
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
text = file.read()
article_doc = self.nlp(text)
truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
text_by_article[article_id] = truncated_text
# process all positive and negative entities, collect all relevant mentions in this article
for mention, entity_pos in correct_entries[article_id].items():
cluster = article_id + "_" + mention
descr = id_to_descr.get(entity_pos)
entities = set()
if descr:
entity = "E_" + str(next_entity_nr) + "_" + cluster
next_entity_nr += 1
gold_by_entity[entity] = 1
desc_by_entity[entity] = descr
entities.add(entity)
# process all positive and negative entities, collect all relevant mentions in this article
for mention, entity_pos in correct_entries[article_id].items():
cluster = article_id + "_" + mention
descr = id_to_descr.get(entity_pos)
entities = set()
if descr:
entity = "E_" + str(next_entity_nr) + "_" + cluster
next_entity_nr += 1
gold_by_entity[entity] = 1
desc_by_entity[entity] = descr
entities.add(entity)
entity_negs = incorrect_entries[article_id][mention]
for entity_neg in entity_negs:
descr = id_to_descr.get(entity_neg)
if descr:
entity = "E_" + str(next_entity_nr) + "_" + cluster
next_entity_nr += 1
gold_by_entity[entity] = 0
desc_by_entity[entity] = descr
entities.add(entity)
entity_negs = incorrect_entries[article_id][mention]
for entity_neg in entity_negs:
descr = id_to_descr.get(entity_neg)
if descr:
entity = "E_" + str(next_entity_nr) + "_" + cluster
next_entity_nr += 1
gold_by_entity[entity] = 0
desc_by_entity[entity] = descr
entities.add(entity)
found_matches = 0
if len(entities) > 1:
entities_by_cluster[cluster] = entities
found_matches = 0
if len(entities) > 1:
entities_by_cluster[cluster] = entities
# find all matches in the doc for the mentions
# TODO: fix this - doesn't look like all entities are found
matcher = PhraseMatcher(self.nlp.vocab)
patterns = list(self.nlp.tokenizer.pipe([mention]))
# find all matches in the doc for the mentions
# TODO: fix this - doesn't look like all entities are found
matcher = PhraseMatcher(self.nlp.vocab)
patterns = list(self.nlp.tokenizer.pipe([mention]))
matcher.add("TerminologyList", None, *patterns)
matches = matcher(article_doc)
matcher.add("TerminologyList", None, *patterns)
matches = matcher(article_doc)
# store sentences
for match_id, start, end in matches:
span = article_doc[start:end]
if mention == span.text:
found_matches += 1
sent_text = span.sent.text
sent_nr = sentence_by_text.get(sent_text, None)
if sent_nr is None:
sent_nr = "S_" + str(next_sent_nr) + article_id
next_sent_nr += 1
text_by_sentence[sent_nr] = sent_text
sentence_by_text[sent_text] = sent_nr
article_by_cluster[cluster] = article_id
sentence_by_cluster[cluster] = sent_nr
# store sentences
for match_id, start, end in matches:
found_matches += 1
span = article_doc[start:end]
assert mention == span.text
sent_text = span.sent.text
sent_nr = sentence_by_text.get(sent_text, None)
if sent_nr is None:
sent_nr = "S_" + str(next_sent_nr) + article_id
next_sent_nr += 1
text_by_sentence[sent_nr] = sent_text
sentence_by_text[sent_text] = sent_nr
article_by_cluster[cluster] = article_id
sentence_by_cluster[cluster] = sent_nr
if found_matches == 0:
# TODO print("Could not find neg instances or sentence matches for", mention, "in", article_id)
entities_by_cluster.pop(cluster, None)
article_by_cluster.pop(cluster, None)
sentence_by_cluster.pop(cluster, None)
for entity in entities:
gold_by_entity.pop(entity, None)
desc_by_entity.pop(entity, None)
if found_matches == 0:
# print("Could not find neg instances or sentence matches for", mention, "in", article_id)
entities_by_cluster.pop(cluster, None)
article_by_cluster.pop(cluster, None)
sentence_by_cluster.pop(cluster, None)
for entity in entities:
gold_by_entity.pop(entity, None)
desc_by_entity.pop(entity, None)
cnt += 1
except:
print("Problem parsing article", article_id)
if to_print:
print()
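As a reference for the mention-matching loop above, here is a small self-contained sketch using the spaCy 2.x PhraseMatcher API that this file relies on (toy text and mention; the committed code stores span.sent.text per cluster in the same way).

import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.blank("en")
nlp.add_pipe(nlp.create_pipe("sentencizer"))  # so span.sent is available on matches

mention = "Douglas Adams"
matcher = PhraseMatcher(nlp.vocab)
patterns = list(nlp.tokenizer.pipe([mention]))
matcher.add("TerminologyList", None, *patterns)

article_doc = nlp("Douglas Adams wrote the guide. It was published in 1979.")
for match_id, start, end in matcher(article_doc):
    span = article_doc[start:end]
    if mention == span.text:                    # same guard as the pre-change code
        print(span.text, "->", span.sent.text)  # the containing sentence is what gets stored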

View File

@@ -111,7 +111,7 @@ if __name__ == "__main__":
print("STEP 6: training", datetime.datetime.now())
my_nlp = spacy.load('en_core_web_md')
trainer = EL_Model(kb=my_kb, nlp=my_nlp)
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100)
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
print()
# STEP 7: apply the EL algorithm on the dev dataset
@@ -120,7 +120,6 @@ if __name__ == "__main__":
run_el.run_el_dev(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, limit=2000)
print()
# TODO coreference resolution
# add_coref()