diff --git a/examples/pipeline/wiki_entity_linking/kb_creator.py b/examples/pipeline/wiki_entity_linking/kb_creator.py
index 4d7bd646b..80d0e21e9 100644
--- a/examples/pipeline/wiki_entity_linking/kb_creator.py
+++ b/examples/pipeline/wiki_entity_linking/kb_creator.py
@@ -56,7 +56,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
             frequency_list.append(freq)
             filtered_title_to_id[title] = entity
 
-    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles")
+    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles with filter frequency", min_entity_freq)
 
     print()
     print(" * train entity encoder", datetime.datetime.now())
diff --git a/examples/pipeline/wiki_entity_linking/run_el.py b/examples/pipeline/wiki_entity_linking/run_el.py
index 52ccccfda..c26e8d65a 100644
--- a/examples/pipeline/wiki_entity_linking/run_el.py
+++ b/examples/pipeline/wiki_entity_linking/run_el.py
@@ -25,9 +25,7 @@ def run_kb_toy_example(kb):
 
 
 def run_el_dev(nlp, kb, training_dir, limit=None):
-    correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir,
-                                                                                 collect_correct=True,
-                                                                                 collect_incorrect=False)
+    correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir)
 
     predictions = list()
     golds = list()
diff --git a/examples/pipeline/wiki_entity_linking/train_el.py b/examples/pipeline/wiki_entity_linking/train_el.py
index 143e38d99..a4026d935 100644
--- a/examples/pipeline/wiki_entity_linking/train_el.py
+++ b/examples/pipeline/wiki_entity_linking/train_el.py
@@ -389,9 +389,7 @@ class EL_Model:
         bp_sent(sent_gradients, sgd=self.sgd_sent)
 
     def _get_training_data(self, training_dir, id_to_descr, dev, limit, to_print):
-        correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir,
-                                                                                         collect_correct=True,
-                                                                                         collect_incorrect=True)
+        correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir)
 
         entities_by_cluster = dict()
         gold_by_entity = dict()
diff --git a/examples/pipeline/wiki_entity_linking/training_set_creator.py b/examples/pipeline/wiki_entity_linking/training_set_creator.py
index 845ce62dc..5d089c620 100644
--- a/examples/pipeline/wiki_entity_linking/training_set_creator.py
+++ b/examples/pipeline/wiki_entity_linking/training_set_creator.py
@@ -16,12 +16,13 @@ from . import wikipedia_processor as wp, kb_creator
 Process Wikipedia interlinks to generate a training dataset for the EL algorithm
 """
 
-ENTITY_FILE = "gold_entities.csv"
+# ENTITY_FILE = "gold_entities.csv"
+ENTITY_FILE = "gold_entities_100000.csv"  # use this file for faster processing
 
 
 def create_training(entity_def_input, training_output):
     wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
-    _process_wikipedia_texts(wp_to_id, training_output, limit=100000000)
+    _process_wikipedia_texts(wp_to_id, training_output, limit=None)
 
 
 def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
@@ -290,75 +291,72 @@ def _write_training_entity(outputfile, article_id, alias, entity, start, end):
     outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
 
 
-def read_training_entities(training_output):
+def is_dev(article_id):
+    return article_id.endswith("3")
+
+
+def read_training_entities(training_output, dev, limit):
     entityfile_loc = training_output + "/" + ENTITY_FILE
     entries_per_article = dict()
+    article_ids = set()
 
     with open(entityfile_loc, mode='r', encoding='utf8') as file:
         for line in file:
-            fields = line.replace('\n', "").split(sep='|')
-            article_id = fields[0]
-            alias = fields[1]
-            wp_title = fields[2]
-            start = fields[3]
-            end = fields[4]
+            if not limit or len(article_ids) < limit:
+                fields = line.replace('\n', "").split(sep='|')
+                article_id = fields[0]
+                if dev == is_dev(article_id) and article_id != "article_id":
+                    article_ids.add(article_id)
 
-            entries_by_offset = entries_per_article.get(article_id, dict())
-            entries_by_offset[start + "-" + end] = (alias, wp_title)
+                    alias = fields[1]
+                    wp_title = fields[2]
+                    start = fields[3]
+                    end = fields[4]
 
-            entries_per_article[article_id] = entries_by_offset
+                    entries_by_offset = entries_per_article.get(article_id, dict())
+                    entries_by_offset[start + "-" + end] = (alias, wp_title)
+
+                    entries_per_article[article_id] = entries_by_offset
 
     return entries_per_article
 
 
-def read_training(nlp, training_dir, dev, limit, to_print):
-    # This method will provide training examples that correspond to the entity annotations found by the nlp object
-    entries_per_article = read_training_entities(training_output=training_dir)
+def read_training(nlp, training_dir, dev, limit):
+    # This method provides training examples that correspond to the entity annotations found by the nlp object
+
+    print("reading training entities")
+    entries_per_article = read_training_entities(training_output=training_dir, dev=dev, limit=limit)
+    print("done reading training entities")
 
     data = []
+    for article_id, entries_by_offset in entries_per_article.items():
+        file_name = article_id + ".txt"
+        try:
+            # parse the article text
+            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as file:
+                text = file.read()
+                article_doc = nlp(text)
 
-    cnt = 0
-    files = listdir(training_dir)
-    for f in files:
-        if not limit or cnt < limit:
-            if dev == run_el.is_dev(f):
-                article_id = f.replace(".txt", "")
-                if cnt % 500 == 0 and to_print:
-                    print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
+                gold_entities = list()
+                for ent in article_doc.ents:
+                    start = ent.start_char
+                    end = ent.end_char
 
-                try:
-                    # parse the article text
-                    with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
-                        text = file.read()
-                        article_doc = nlp(text)
+                    entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
+                    if entity_tuple:
+                        alias, wp_title = entity_tuple
+                        if ent.text != alias:
+                            print("Non-matching entity in", article_id, start, end)
+                        else:
+                            gold_entities.append((start, end, wp_title))
 
-                    entries_by_offset = entries_per_article.get(article_id, dict())
+                if gold_entities:
+                    gold = GoldParse(doc=article_doc, links=gold_entities)
+                    data.append((article_doc, gold))
 
-                    gold_entities = list()
-                    for ent in article_doc.ents:
-                        start = ent.start_char
-                        end = ent.end_char
+        except Exception as e:
+            print("Problem parsing article", article_id)
+            print(e)
+            raise e
 
-                    entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
-                    if entity_tuple:
-                        alias, wp_title = entity_tuple
-                        if ent.text != alias:
-                            print("Non-matching entity in", article_id, start, end)
-                        else:
-                            gold_entities.append((start, end, wp_title))
-
-                    if gold_entities:
-                        gold = GoldParse(doc=article_doc, links=gold_entities)
-                        data.append((article_doc, gold))
-
-                    cnt += 1
-                except Exception as e:
-                    print("Problem parsing article", article_id)
-                    print(e)
-                    raise e
-
-    if to_print:
-        print()
-        print("Processed", cnt, "training articles, dev=" + str(dev))
-        print()
     return data
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index 1e5280f89..b3b3479e2 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -29,7 +29,7 @@ NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
 TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
 
 MAX_CANDIDATES = 10
-MIN_ENTITY_FREQ = 200
+MIN_ENTITY_FREQ = 20
 MIN_PAIR_OCC = 5
 DOC_SENT_CUTOFF = 2
 EPOCHS = 10
@@ -47,21 +47,21 @@ def run_pipeline():
    # one-time methods to create KB and write to file
    to_create_prior_probs = False
    to_create_entity_counts = False
-    to_create_kb = True
+    to_create_kb = False
 
    # read KB back in from file
-    to_read_kb = False
+    to_read_kb = True
    to_test_kb = False
 
    # create training dataset
    create_wp_training = False
 
    # train the EL pipe
-    train_pipe = False
-    measure_performance = False
+    train_pipe = True
+    measure_performance = True
 
    # test the EL pipe on a simple example
-    to_test_pipeline = False
+    to_test_pipeline = True
 
    # write the NLP object, read back in and test again
    test_nlp_io = False
@@ -135,46 +135,50 @@ def run_pipeline():
        print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
        train_limit = 10
        dev_limit = 2
-        print("Training on", train_limit, "articles")
-        print("Dev testing on", dev_limit, "articles")
-        print()
 
        train_data = training_set_creator.read_training(nlp=nlp_2,
                                                        training_dir=TRAINING_DIR,
                                                        dev=False,
-                                                        limit=train_limit,
-                                                        to_print=False)
+                                                        limit=train_limit)
+
+        print("Training on", len(train_data), "articles")
+        print()
+
+        if not train_data:
+            print("Did not find any training data")
+
+        else:
+            for itn in range(EPOCHS):
+                random.shuffle(train_data)
+                losses = {}
+                batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
+                batchnr = 0
+
+                with nlp_2.disable_pipes(*other_pipes):
+                    for batch in batches:
+                        try:
+                            docs, golds = zip(*batch)
+                            nlp_2.update(
+                                docs,
+                                golds,
+                                drop=DROPOUT,
+                                losses=losses,
+                            )
+                            batchnr += 1
+                        except Exception as e:
+                            print("Error updating batch", e)
+
+                losses['entity_linker'] = losses['entity_linker'] / batchnr
+                print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
 
        dev_data = training_set_creator.read_training(nlp=nlp_2,
                                                      training_dir=TRAINING_DIR,
                                                      dev=True,
-                                                      limit=dev_limit,
-                                                      to_print=False)
+                                                      limit=dev_limit)
+        print("Dev testing on", len(dev_data), "articles")
+        print()
 
-        for itn in range(EPOCHS):
-            random.shuffle(train_data)
-            losses = {}
-            batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
-            batchnr = 0
-
-            with nlp_2.disable_pipes(*other_pipes):
-                for batch in batches:
-                    try:
-                        docs, golds = zip(*batch)
-                        nlp_2.update(
-                            docs,
-                            golds,
-                            drop=DROPOUT,
-                            losses=losses,
-                        )
-                        batchnr += 1
-                    except Exception as e:
-                        print("Error updating batch", e)
-
-            losses['entity_linker'] = losses['entity_linker'] / batchnr
-            print("Epoch, train loss", itn, round(losses['entity_linker'], 2))
-
-        if measure_performance:
+        if len(dev_data) and measure_performance:
            print()
            print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
            print()
diff --git a/examples/pipeline/wiki_entity_linking/wikidata_processor.py b/examples/pipeline/wiki_entity_linking/wikidata_processor.py
index f6a6cbe23..967849abb 100644
--- a/examples/pipeline/wiki_entity_linking/wikidata_processor.py
+++ b/examples/pipeline/wiki_entity_linking/wikidata_processor.py
@@ -104,7 +104,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
                    if lang_aliases:
                        for item in lang_aliases:
                            if to_print:
-                                    print("alias (" + lang + "):", item["value"])
+                                print("alias (" + lang + "):", item["value"])
 
        if to_print:
            print()
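
For reference, a minimal sketch of how the reworked reader can be called after this change. It assumes training_set_creator is importable from the example directory, that nlp is a loaded pipeline with an NER component, and that the model name, path, and limits below are placeholders rather than values taken from this patch:

    import spacy

    from examples.pipeline.wiki_entity_linking import training_set_creator  # assumed import path

    TRAINING_DIR = "/path/to/training_data_nel/"  # placeholder: output directory of create_training()
    nlp = spacy.load("en_core_web_lg")            # placeholder: any pipeline with an NER pipe

    # is_dev() routes articles whose ID ends in "3" to the dev split; the rest stay in train.
    # limit caps the number of distinct article IDs read from the gold entities CSV.
    train_data = training_set_creator.read_training(nlp=nlp, training_dir=TRAINING_DIR, dev=False, limit=10)
    dev_data = training_set_creator.read_training(nlp=nlp, training_dir=TRAINING_DIR, dev=True, limit=2)

    # Each entry is an (article Doc, GoldParse) pair whose links hold the gold (start, end, WP title) spans.
    print("Read", len(train_data), "training docs and", len(dev_data), "dev docs")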