import sys
import os
import re
import argparse
from collections import defaultdict

# Matches one double-quoted parameter value, quotes included.
_QUOTED_PARAM = re.compile(r'"[^"]*"')


def compute_accuracy(pred, gold):
    """Return True when the predicted program equals the gold program exactly."""
    return pred == gold


def compute_accuracy_without_params(pred, gold):
    """Return True when pred and gold match once quoted parameters are removed.

    Every double-quoted span (quotes included) is deleted from both strings
    and the remainders are compared, so two programs that differ only in
    their quoted parameter values count as equal.

    The earlier implementation built one partially-cleaned string per quoted
    value and compared the resulting lists. That returned True for any two
    programs without quotes (two empty lists) and False for matching programs
    with more than one differing parameter.
    """
    return _QUOTED_PARAM.sub('', pred) == _QUOTED_PARAM.sub('', gold)


def compute_grammar_accuracy(pred):
    """Return whether splitting pred on single spaces yields any token."""
    # NOTE(review): str.split(' ') always returns at least one element (even
    # for the empty string), so this is always True as written -- confirm the
    # intended check before relying on the "_grammar" column.
    return len(pred.split(' ')) != 0
def _function_name(qualified):
    """Return the part of a dotted function identifier after the last '.'."""
    # [-1] instead of [1]: an identifier without a dot is returned unchanged
    # rather than raising IndexError.
    return qualified.rsplit('.', 1)[-1]


def run():
    """Score every decoded program against its gold program and write a report.

    Reads the paths from the module-level ``args``:

    * ``args.reference_gold`` / ``args.gold_program`` are passed to
      ``find_indices`` to obtain the index order used to re-align the input
      sentences with the gold/predicted files.
    * ``args.input_sentences``, ``args.gold_program`` and
      ``args.predicted_program`` are read line by line in lockstep.
    * ``args.output_file`` receives one ``' || '``-separated record per
      example, followed by two blank lines.

    While scoring, device and function confusions are tallied and printed to
    stdout at the end together with the number of examples that had a device
    error or (device correct but) a function error.
    """
    indices = find_indices(args.reference_gold, args.gold_program)
    with open(args.input_sentences, 'r') as input_file:
        inputs = input_file.readlines()
    # Re-order the input sentences to follow the gold/predicted files.
    res = [inputs[i] for i in indices]

    # (gold_device, predicted_device) -> count
    errors_dev = defaultdict(int)
    # gold_device -> (gold_function, predicted_function) -> count
    errors_func = defaultdict(lambda: defaultdict(int))

    cnt_dev = 0
    cnt_func = 0

    with open(args.gold_program, 'r') as gold_file,\
            open(args.predicted_program, 'r') as pred_file,\
            open(args.output_file, 'w') as out:

        for sentence, gold, pred in zip(res, gold_file, pred_file):
            # NOTE(review): the search string is empty, so this replace() is a
            # no-op as written; the marker it was meant to strip appears to be
            # missing -- confirm and restore it.
            sentence = sentence.replace(r'', '').strip()
            gold = gold.strip()
            pred = pred.strip()

            accuracy = compute_accuracy(pred, gold)
            accuracy_without_params = compute_accuracy_without_params(pred, gold)
            grammar_accuracy = compute_grammar_accuracy(pred)
            function_correctness = compute_funtion_correctness(pred, gold)
            device_correctness = compute_device_correctness(pred, gold)
            correct_tokens = compute_correct_tokens(pred, gold)
            correct_quotes = compute_correct_quotes(pred, gold)

            # ---- error analysis ----
            # Distinct loop names are used below so that `gold` and `pred`
            # still hold the full programs when the record is written.
            if not device_correctness:
                cnt_dev += 1
                gold_devs = get_devices(gold)
                pred_devs = get_devices(pred)
                # Only positionally comparable when both sides name the same
                # number of devices.
                if len(gold_devs) == len(pred_devs):
                    for gold_dev, pred_dev in zip(gold_devs, pred_devs):
                        if gold_dev != pred_dev:
                            errors_dev[(gold_dev, pred_dev)] += 1

            elif not function_correctness:
                cnt_func += 1
                gold_funcs = get_functions(gold)
                pred_funcs = get_functions(pred)
                if len(gold_funcs) == len(pred_funcs):
                    devices = get_devices(gold)
                    # zip stops at the shortest sequence, so a device list
                    # longer than the function list cannot index out of range.
                    for device, gold_func, pred_func in zip(devices, gold_funcs, pred_funcs):
                        if gold_func != pred_func:
                            confusion = (_function_name(gold_func), _function_name(pred_func))
                            errors_func[device][confusion] += 1
            # ---- end error analysis ----

            fields = [
                sentence,
                gold,
                pred,
                str(accuracy),
                str(accuracy_without_params) + '_w/o_params',
                str(grammar_accuracy) + '_grammar',
                str(function_correctness) + '_function',
                str(device_correctness) + '_device',
            ]
            # NOTE(review): `!= False` also skips a score equal to 0; the
            # helpers are not visible here, so the comparison is kept as-is.
            if correct_quotes != False:
                fields.append('{0:.2f}'.format(correct_quotes) + '%_correct_quotes')
            if correct_tokens != False:
                fields.append('{0:.2f}'.format(correct_tokens) + '%_correct_tokens')

            # One record line followed by two blank separator lines.
            out.write(' || '.join(fields))
            out.write('\n\n\n')

    print('cnt_dev: ', cnt_dev)
    print('cnt_func: ', cnt_func)
    print('errors_dev: ', errors_dev.items())
    print('errors_func: ', errors_func.items())


if __name__ == '__main__':
    run()