mirror of https://github.com/explosion/spaCy.git
undersampling negatives
This commit is contained in:
parent
2fa3fac851
commit
7b13e3d56f
|
@ -56,20 +56,19 @@ class EL_Model:
|
|||
entity_descr_output,
|
||||
False,
|
||||
trainlimit,
|
||||
balance=True,
|
||||
to_print=False)
|
||||
|
||||
dev_inst, dev_pos, dev_neg, dev_texts = self._get_training_data(training_dir,
|
||||
entity_descr_output,
|
||||
True,
|
||||
devlimit,
|
||||
balance=False,
|
||||
to_print=False)
|
||||
self._begin_training()
|
||||
|
||||
print()
|
||||
self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_random", calc_random=True)
|
||||
self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_random", calc_random=True)
|
||||
print()
|
||||
self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_pre", avg=False)
|
||||
self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_pre", avg=False)
|
||||
|
||||
instance_pos_count = 0
|
||||
|
@ -120,9 +119,6 @@ class EL_Model:
|
|||
print()
|
||||
print("Trained on", instance_pos_count, "/", instance_neg_count, "instances pos/neg")
|
||||
|
||||
if self.PRINT_TRAIN:
|
||||
self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_post_avg", avg=True)
|
||||
|
||||
def _test_dev(self, instances, pos, neg, texts_by_id, print_string, avg=False, calc_random=False):
|
||||
predictions = list()
|
||||
golds = list()
|
||||
|
@ -290,7 +286,7 @@ class EL_Model:
|
|||
bp_doc([doc_gradient], sgd=self.sgd_article)
|
||||
bp_entity(entity_gradients, sgd=self.sgd_entity)
|
||||
|
||||
def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
|
||||
def _get_training_data(self, training_dir, entity_descr_output, dev, limit, balance, to_print):
|
||||
id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
|
||||
|
||||
correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
|
||||
|
@ -324,12 +320,16 @@ class EL_Model:
|
|||
pos_entities[article_id + "_" + mention] = descr
|
||||
|
||||
for mention, entity_negs in incorrect_entries[article_id].items():
|
||||
neg_count = 0
|
||||
for entity_neg in entity_negs:
|
||||
descr = id_to_descr.get(entity_neg)
|
||||
if descr:
|
||||
descr_list = neg_entities.get(article_id + "_" + mention, [])
|
||||
descr_list.append(descr)
|
||||
neg_entities[article_id + "_" + mention] = descr_list
|
||||
# if balance, keep only 1 negative instance for each positive instance
|
||||
if neg_count < 1 or not balance:
|
||||
descr_list = neg_entities.get(article_id + "_" + mention, [])
|
||||
descr_list.append(descr)
|
||||
neg_entities[article_id + "_" + mention] = descr_list
|
||||
neg_count += 1
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
|
|
|
@ -111,7 +111,7 @@ if __name__ == "__main__":
|
|||
print("STEP 6: training", datetime.datetime.now())
|
||||
my_nlp = spacy.load('en_core_web_md')
|
||||
trainer = EL_Model(kb=my_kb, nlp=my_nlp)
|
||||
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10, devlimit=10)
|
||||
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=500, devlimit=20)
|
||||
print()
|
||||
|
||||
# STEP 7: apply the EL algorithm on the dev dataset
|
||||
|
|
Loading…
Reference in New Issue