diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py index 2d218ed60..20a5e4428 100644 --- a/examples/pipeline/wiki_entity_linking/train_el.py +++ b/examples/pipeline/wiki_entity_linking/train_el.py @@ -56,20 +56,19 @@ class EL_Model: entity_descr_output, False, trainlimit, + balance=True, to_print=False) dev_inst, dev_pos, dev_neg, dev_texts = self._get_training_data(training_dir, entity_descr_output, True, devlimit, + balance=False, to_print=False) self._begin_training() print() - self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_random", calc_random=True) self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_random", calc_random=True) - print() - self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_pre", avg=False) self._test_dev(dev_inst, dev_pos, dev_neg, dev_texts, print_string="dev_pre", avg=False) instance_pos_count = 0 @@ -120,9 +119,6 @@ class EL_Model: print() print("Trained on", instance_pos_count, "/", instance_neg_count, "instances pos/neg") - if self.PRINT_TRAIN: - self._test_dev(train_inst, train_pos, train_neg, train_texts, print_string="train_post_avg", avg=True) - def _test_dev(self, instances, pos, neg, texts_by_id, print_string, avg=False, calc_random=False): predictions = list() golds = list() @@ -290,7 +286,7 @@ class EL_Model: bp_doc([doc_gradient], sgd=self.sgd_article) bp_entity(entity_gradients, sgd=self.sgd_entity) - def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print): + def _get_training_data(self, training_dir, entity_descr_output, dev, limit, balance, to_print): id_to_descr = kb_creator._get_id_to_description(entity_descr_output) correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir, @@ -324,12 +320,16 @@ class EL_Model: pos_entities[article_id + "_" + mention] = descr for mention, entity_negs in incorrect_entries[article_id].items(): + neg_count = 0 for entity_neg in entity_negs: descr = id_to_descr.get(entity_neg) if descr: - descr_list = neg_entities.get(article_id + "_" + mention, []) - descr_list.append(descr) - neg_entities[article_id + "_" + mention] = descr_list + # if balance, keep only 1 negative instance for each positive instance + if neg_count < 1 or not balance: + descr_list = neg_entities.get(article_id + "_" + mention, []) + descr_list.append(descr) + neg_entities[article_id + "_" + mention] = descr_list + neg_count += 1 if to_print: print() diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py index 23c12bfe6..0927fb394 100644 --- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py +++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py @@ -111,7 +111,7 @@ if __name__ == "__main__": print("STEP 6: training", datetime.datetime.now()) my_nlp = spacy.load('en_core_web_md') trainer = EL_Model(kb=my_kb, nlp=my_nlp) - trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10, devlimit=10) + trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=500, devlimit=20) print() # STEP 7: apply the EL algorithm on the dev dataset