update error analysis
This commit is contained in:
parent
787b0887dc
commit
a9b48d2b97
|
@ -28,6 +28,7 @@ import sys
|
|||
import os
|
||||
import re
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
|
@ -42,6 +43,14 @@ args = parser.parse_args()
|
|||
def compute_accuracy(pred, gold):
    """Tell whether the predicted program matches the gold program exactly.

    Args:
        pred: the predicted program.
        gold: the reference (gold) program.

    Returns:
        The result of comparing ``pred`` and ``gold`` with ``==``.
    """
    is_exact_match = pred == gold
    return is_exact_match
|
||||
|
||||
def compute_accuracy_without_params(pred, gold):
    """Compare two programs after removing their quoted parameter values.

    The quoted values are obtained from ``get_quotes`` and every one of them
    is removed from its own program; the two remaining strings are then
    compared.

    The previous implementation built one partially cleaned copy per quoted
    value and compared the resulting lists. That reported a match for any two
    programs without quotes (two empty lists) and left the other parameters
    in place whenever a program had more than one quoted value.

    Args:
        pred: the predicted program string.
        gold: the reference (gold) program string.

    Returns:
        True when the programs are equal once all quoted values are removed.
    """
    pred_list, gold_list = get_quotes(pred, gold)

    # Remove every quoted value from the same string, so that a single
    # fully cleaned program remains on each side.
    pred_cleaned = pred
    for val in pred_list:
        pred_cleaned = pred_cleaned.replace(val, '')

    gold_cleaned = gold
    for val in gold_list:
        gold_cleaned = gold_cleaned.replace(val, '')

    return pred_cleaned == gold_cleaned
|
||||
|
||||
def compute_grammar_accuracy(pred):
    """Tell whether the prediction contains at least one token.

    The previous check, ``len(pred.split(' ')) != 0``, was always True
    because ``str.split(' ')`` returns ``['']`` for an empty string and never
    an empty list. Splitting on arbitrary whitespace yields an empty list
    for empty or blank predictions, so those are now reported as False.

    Args:
        pred: the predicted program string.

    Returns:
        True when ``pred`` has at least one whitespace-separated token,
        False otherwise.
    """
    return len(pred.split()) != 0
|
||||
|
||||
|
@ -80,6 +89,7 @@ def compute_correct_quotes(pred, gold):
|
|||
|
||||
|
||||
def get_quotes(pred, gold):
|
||||
|
||||
quotes_list_pred = []
|
||||
quotes_list_gold = []
|
||||
quoted = re.compile('"[^"]*"')
|
||||
|
@ -91,7 +101,9 @@ def get_quotes(pred, gold):
|
|||
return quotes_list_pred, quotes_list_gold
|
||||
|
||||
def find_indices(ref, shuf):
|
||||
|
||||
# models preprocess datasets before training and testing.
|
||||
# during this procedure the dataset gets shuffled
|
||||
# this function finds a mapping between the original ordering of the data and the shuffled ordering in the dataset
|
||||
ref_list = []
|
||||
shuf_list = []
|
||||
|
||||
|
@ -99,7 +111,6 @@ def find_indices(ref, shuf):
|
|||
for line in f_ref:
|
||||
line = line[:-1].lower()
|
||||
ref_list.append(line)
|
||||
|
||||
with open(shuf, 'r') as f_shuf:
|
||||
for line in f_shuf:
|
||||
line = line[1:-2].replace(r'\"', '"').lower()
|
||||
|
@ -111,41 +122,80 @@ def find_indices(ref, shuf):
|
|||
|
||||
return indices
|
||||
|
||||
indices = find_indices(args.reference_gold, args.gold_program)
|
||||
|
||||
inputs = []
|
||||
with open(args.input_sentences, 'r') as input_file:
|
||||
for line in input_file:
|
||||
inputs.append(line)
|
||||
def run():
    """Write a per-example evaluation report and print error statistics.

    Reads the paths given by ``args.reference_gold``, ``args.gold_program``,
    ``args.input_sentences``, ``args.predicted_program`` and
    ``args.output_file``. The input sentences are reordered with the indices
    returned by ``find_indices`` and then paired line by line with the gold
    and predicted programs.

    For every example one ``' || '``-separated report line is written to the
    output file, followed by two empty lines. Device and function confusions
    are counted and printed once all examples are processed.
    """
    indices = find_indices(args.reference_gold, args.gold_program)
    with open(args.input_sentences, 'r') as input_file:
        inputs = list(input_file)

    # Put the input sentences in the same order as the gold/predicted files.
    res = [inputs[i] for i in indices]

    # (gold device, predicted device) -> number of confusions
    errors_dev = defaultdict(int)
    # device -> (gold function, predicted function) -> number of confusions
    errors_func = defaultdict(lambda: defaultdict(int))
    cnt_dev = 0
    cnt_func = 0

    with open(args.gold_program, 'r') as gold_file, \
            open(args.predicted_program, 'r') as pred_file, \
            open(args.output_file, 'w') as out:

        for sentence, gold, pred in zip(res, gold_file, pred_file):
            sentence = sentence.replace(r'<s>', '').strip()
            gold = gold.strip()
            pred = pred.strip()

            accuracy = compute_accuracy(pred, gold)
            accuracy_without_params = compute_accuracy_without_params(pred, gold)
            grammar_accuracy = compute_grammar_accuracy(pred)
            function_correctness = compute_funtion_correctness(pred, gold)
            device_correctness = compute_device_correctness(pred, gold)
            correct_tokens = compute_correct_tokens(pred, gold)
            correct_quotes = compute_correct_quotes(pred, gold)

            ##########
            # error analysis
            if not device_correctness:
                gold_devs = get_devices(gold)
                pred_devs = get_devices(pred)
                cnt_dev += 1
                if len(gold_devs) == len(pred_devs):
                    # Use dedicated loop names: rebinding `gold` here would
                    # corrupt the gold program written to the report below.
                    for gold_dev, pred_dev in zip(gold_devs, pred_devs):
                        if gold_dev != pred_dev:
                            errors_dev[(gold_dev, pred_dev)] += 1

            elif not function_correctness:
                gold_funcs = get_functions(gold)
                pred_funcs = get_functions(pred)
                cnt_func += 1
                if len(gold_funcs) == len(pred_funcs):
                    devices = get_devices(gold)
                    # zip stops at the shortest sequence, so a device count
                    # that differs from the function count cannot raise
                    # IndexError.
                    for device, gold_func, pred_func in zip(
                            devices, gold_funcs, pred_funcs):
                        if gold_func != pred_func:
                            # [-1] equals [1] whenever the name contains a
                            # dot and does not fail when it does not.
                            pair = (gold_func.rsplit('.', 1)[-1],
                                    pred_func.rsplit('.', 1)[-1])
                            errors_func[device][pair] += 1
            ##########

            fields = [
                sentence,
                gold,
                pred,
                str(accuracy),
                str(accuracy_without_params) + '_w/o_params',
                str(grammar_accuracy) + '_grammar',
                str(function_correctness) + '_function',
                str(device_correctness) + '_device',
            ]
            # NOTE(review): `!= False` also skips a value of 0; kept as in
            # the original because the helpers' return contract is not
            # visible here - confirm before changing to `is not False`.
            if correct_quotes != False:
                fields.append("{0:.2f}".format(correct_quotes) + '%_correct_quotes')
            if correct_tokens != False:
                fields.append("{0:.2f}".format(correct_tokens) + '%_correct_tokens')

            # One report line followed by two empty lines per example.
            out.write(' || '.join(fields) + '\n\n\n')

    print('cnt_dev: ', cnt_dev)
    print('cnt_func: ', cnt_func)
    print('errors_dev: ', errors_dev.items())
    print('errors_func: ', errors_func.items())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
|
Loading…
Reference in New Issue