undersampling negatives

This commit is contained in:
svlandeg 2019-05-21 18:35:10 +02:00
parent 2fa3fac851
commit 7b13e3d56f
2 changed files with 11 additions and 11 deletions

View File

@ -56,20 +56,19 @@ class EL_Model:
entity_descr_output,
False,
trainlimit,
balance=True,
to_print=False)
dev_inst, dev_pos, dev_neg, dev_texts = self._get_training_data(training_dir,
entity_descr_output,
True,
devlimit,
balance=False,
to_print=False)
self._begin_training()
print()
self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_random", calc_random=True)
self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_random", calc_random=True)
print()
self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_pre", avg=False)
self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_pre", avg=False)
instance_pos_count = 0
@ -120,9 +119,6 @@ class EL_Model:
print()
print("Trained on", instance_pos_count, "/", instance_neg_count, "instances pos/neg")
if self.PRINT_TRAIN:
self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_post_avg", avg=True)
def _test_dev(self, instances, pos, neg, texts_by_id, print_string, avg=False, calc_random=False):
predictions = list()
golds = list()
@ -290,7 +286,7 @@ class EL_Model:
bp_doc([doc_gradient], sgd=self.sgd_article)
bp_entity(entity_gradients, sgd=self.sgd_entity)
def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
def _get_training_data(self, training_dir, entity_descr_output, dev, limit, balance, to_print):
id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
@ -324,12 +320,16 @@ class EL_Model:
pos_entities[article_id + "_" + mention] = descr
for mention, entity_negs in incorrect_entries[article_id].items():
neg_count = 0
for entity_neg in entity_negs:
descr = id_to_descr.get(entity_neg)
if descr:
descr_list = neg_entities.get(article_id + "_" + mention, [])
descr_list.append(descr)
neg_entities[article_id + "_" + mention] = descr_list
# if balance, keep only 1 negative instance for each positive instance
if neg_count < 1 or not balance:
descr_list = neg_entities.get(article_id + "_" + mention, [])
descr_list.append(descr)
neg_entities[article_id + "_" + mention] = descr_list
neg_count += 1
if to_print:
print()

View File

@ -111,7 +111,7 @@ if __name__ == "__main__":
print("STEP 6: training", datetime.datetime.now())
my_nlp = spacy.load('en_core_web_md')
trainer = EL_Model(kb=my_kb, nlp=my_nlp)
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10, devlimit=10)
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=500, devlimit=20)
print()
# STEP 7: apply the EL algorithm on the dev dataset