From 66813a1fdcfa2b1f2c9e3af0b8b3922427d1d73a Mon Sep 17 00:00:00 2001 From: svlandeg Date: Tue, 11 Jun 2019 14:18:20 +0200 Subject: [PATCH] speed up predictions --- .../wiki_entity_linking/wiki_nel_pipeline.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py index 6e4ca6970..8753450bb 100644 --- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py +++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py @@ -115,8 +115,8 @@ def run_pipeline(): # STEP 6: create the entity linking pipe if train_pipe: - train_limit = 5 - dev_limit = 2 + train_limit = 100 + dev_limit = 20 print("Training on", train_limit, "articles") print("Dev testing on", dev_limit, "articles") print() @@ -155,22 +155,25 @@ def run_pipeline(): losses=losses, ) + # print(" measuring accuracy 1-1") el_pipe.context_weight = 1 el_pipe.prior_weight = 1 - dev_acc_1_1 = _measure_accuracy(dev_data, nlp) - train_acc_1_1 = _measure_accuracy(train_data, nlp) + dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe) + train_acc_1_1 = _measure_accuracy(train_data, el_pipe) + # print(" measuring accuracy 0-1") el_pipe.context_weight = 0 el_pipe.prior_weight = 1 - dev_acc_0_1 = _measure_accuracy(dev_data, nlp) - train_acc_0_1 = _measure_accuracy(train_data, nlp) + dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe) + train_acc_0_1 = _measure_accuracy(train_data, el_pipe) + # print(" measuring accuracy 1-0") el_pipe.context_weight = 1 el_pipe.prior_weight = 0 - dev_acc_1_0 = _measure_accuracy(dev_data, nlp) - train_acc_1_0 = _measure_accuracy(train_data, nlp) + dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe) + train_acc_1_0 = _measure_accuracy(train_data, el_pipe) - print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, losses['entity_linker'], + print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, round(losses['entity_linker'], 2), round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/", round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2)) @@ -184,12 +187,13 @@ def run_pipeline(): print("STOP", datetime.datetime.now()) -def _measure_accuracy(data, nlp): +def _measure_accuracy(data, el_pipe): correct = 0 incorrect = 0 - texts = [d.text for d, g in data] - docs = list(nlp.pipe(texts)) + docs = [d for d, g in data] + docs = el_pipe.pipe(docs) + golds = [g for d, g in data] for doc, gold in zip(docs, golds):