From 96413a2a2818c1625a539c296222d94796765568 Mon Sep 17 00:00:00 2001
From: Bryan Marcus McCann
Date: Thu, 9 Aug 2018 19:35:36 +0000
Subject: [PATCH] adding script for external wikisql evaluation

---
 convert_to_logical_forms.py | 81 +++++++++++++++++++++++++++++++++++++
 predict.py                  | 20 +++++++--
 2 files changed, 97 insertions(+), 4 deletions(-)
 create mode 100644 convert_to_logical_forms.py

diff --git a/convert_to_logical_forms.py b/convert_to_logical_forms.py
new file mode 100644
index 00000000..d8102273
--- /dev/null
+++ b/convert_to_logical_forms.py
@@ -0,0 +1,81 @@
+#!/usr/bin/env python3
+from text.torchtext.datasets.generic import Query
+from argparse import ArgumentParser
+import os
+import re
+import ujson as json
+from metrics import to_lf
+
+
+def correct_format(x):
+    if len(x.keys()) == 0:
+        x = {'query': None, 'error': 'Invalid'}
+    else:
+        c = x['conds']
+        proper = True
+        for cc in c:
+            if len(cc) < 3:  # a well-formed condition is [column, operator, value]
+                proper = False
+        if proper:
+            x = {'query': x, 'error': ''}
+        else:
+            x = {'query': None, 'error': 'Invalid'}
+    return x
+
+
+def write_logical_forms(greedy, args):
+    data_dir = os.path.join(args.data, 'wikisql', 'data')
+    path = os.path.join(data_dir, 'dev.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.jsonl')
+    table_path = os.path.join(data_dir, 'dev.tables.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.tables.jsonl')
+    with open(table_path) as tables_file:
+        tables = [json.loads(line) for line in tables_file]
+        id_to_tables = {x['id']: x for x in tables}
+
+    examples = []
+    with open(path) as example_file:
+        for line in example_file:
+            entry = json.loads(line)
+            table = id_to_tables[entry['table_id']]
+            sql = entry['sql']
+            header = table['header']
+            a = repr(Query.from_dict(entry['sql'], table['header']))
+            ex = {'sql': sql, 'header': header, 'answer': a, 'table': table}
+            examples.append(ex)
+
+    with open(args.output, 'a') as f:
+        count = 0
+        correct = 0
+        text_answers = []
+        for idx, (g, ex) in enumerate(zip(greedy, examples)):
+            count += 1
+            text_answers.append([ex['answer'].lower()])
+            try:
+                lf = to_lf(g, ex['table'])
+                f.write(json.dumps(correct_format(lf)) + '\n')
+                gt = ex['sql']
+                conds = gt['conds']
+                lower_conds = []
+                for c in conds:
+                    lc = list(c)  # copy so the ground-truth condition is not mutated in place
+                    lc[2] = str(lc[2]).lower()
+                    lower_conds.append(lc)
+                gt['conds'] = lower_conds
+                correct += lf == gt
+            except Exception:  # unparseable prediction; emit an explicit failure marker
+                f.write(json.dumps(correct_format({})) + '\n')
+
+if __name__ == '__main__':
+    parser = ArgumentParser()
+    parser.add_argument('data', help='path to the directory containing data for WikiSQL')
+    parser.add_argument('predictions', help='path to the prediction file, containing one prediction per line')
+    parser.add_argument('output', help='path to which logical forms are written line by line')
+    parser.add_argument('evaluate', help='whether to evaluate on the \'validation\' or \'test\' set')
+    parser.add_argument('--ids', help='path to a file of indices, one integer per line, giving the index into the dev/test set of the prediction on the corresponding line of \'predictions\'')
+    args = parser.parse_args()
+    with open(args.predictions) as f:
+        greedy = [l for l in f]
+    if args.ids is not None:
+        with open(args.ids) as f:
+            ids = [int(l.strip()) for l in f]
+        greedy = [x[1] for x in sorted([(i, g) for i, g in zip(ids, greedy)])]
+    write_logical_forms(greedy, args)
diff --git a/predict.py b/predict.py
index d5fce243..541ed2d9 100644
--- a/predict.py
+++ b/predict.py
@@ -85,6 +85,9 @@ def run(args, field, val_sets, model):
         prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.txt')
         answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.gold.txt')
         results_file_name = answer_file_name.replace('gold', 'results')
+        if 'sql' in task:
+            ids_file_name = answer_file_name.replace('gold', 'ids')
+            wikisql_ids = []  # dataset indices gathered while predicting; written out below
         if os.path.exists(prediction_file_name):
             print('** ', prediction_file_name, ' already exists**')
         if os.path.exists(answer_file_name):
@@ -94,7 +97,8 @@ def run(args, field, val_sets, model):
             with open(results_file_name) as results_file:
                 for l in results_file:
                     print(l)
-            continue
+            if 'schema' not in task:
+                continue
 
         for x in [prediction_file_name, answer_file_name, results_file_name]:
             os.makedirs(os.path.dirname(x), exist_ok=True)
@@ -104,13 +108,21 @@ def run(args, field, val_sets, model):
                 for batch_idx, batch in enumerate(it):
                     _, p = model(batch)
                     p = field.reverse(p)
-                    for pp in p:
+                    for i, pp in enumerate(p):
+                        if 'sql' in task:
+                            wikisql_id = int(batch.wikisql_id[i])
+                            wikisql_ids.append(wikisql_id)
                         prediction_file.write(pp + '\n')
                         predictions.append(pp)
         else:
             with open(prediction_file_name) as prediction_file:
                 predictions = [x.strip() for x in prediction_file.readlines()]
 
+        if 'sql' in task:
+            with open(ids_file_name, 'w') as id_file:
+                for i in wikisql_ids:
+                    id_file.write(json.dumps(i) + '\n')
+
         def from_all_answers(an):
             return [it.dataset.all_answers[sid] for sid in an.tolist()]
@@ -146,7 +158,7 @@ def get_args():
     parser = ArgumentParser()
     parser.add_argument('--path', required=True)
    parser.add_argument('--evaluate', type=str, required=True)
-    parser.add_argument('--tasks', default=['wikisql', 'woz.en', 'cnn_dailymail', 'iwslt.en.de', 'zre', 'srl', 'squad', 'sst', 'multinli.in.out'], nargs='+')
+    parser.add_argument('--tasks', default=['wikisql', 'woz.en', 'cnn_dailymail', 'iwslt.en.de', 'zre', 'srl', 'squad', 'sst', 'multinli.in.out', 'schema'], nargs='+')
     parser.add_argument('--gpus', type=int, help='gpus to use', required=True)
     parser.add_argument('--seed', default=123, type=int, help='Random seed.')
     parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.')
@@ -178,7 +190,7 @@ def get_args():
         'schema': 'em'}
 
     if os.path.exists(os.path.join(args.path, 'process_0.log')):
-        args.best_checkpoint = get_best(args, lines)
+        args.best_checkpoint = get_best(args)
     else:
         args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)
 
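Note (not part of the patch): the two halves form a pipeline. For WikiSQL, predict.py writes one prediction per line to <task>.txt and the matching dataset indices to <task>.ids.txt; convert_to_logical_forms.py then reorders the predictions by those indices and writes one JSON object per line for the external WikiSQL evaluator. Below is a minimal sketch of that per-line output contract, assuming the standard WikiSQL query layout ('sel', 'agg', and 'conds' triples of [column, operator, value]); the concrete values and the file name logical_forms.jsonl are illustrative only, and stdlib json stands in for the script's ujson.

import json

# A parseable prediction: correct_format wraps it as {'query': ..., 'error': ''}.
parsed = {'sel': 1, 'agg': 0, 'conds': [[2, 0, 'some value']]}  # illustrative values
good_line = {'query': parsed, 'error': ''}

# An unparseable prediction (to_lf raised): correct_format({}) yields a failure
# marker the evaluator can score as incorrect instead of crashing on.
bad_line = {'query': None, 'error': 'Invalid'}

with open('logical_forms.jsonl', 'w') as f:  # stands in for args.output
    for line in (good_line, bad_line):
        f.write(json.dumps(line) + '\n')

With the file names produced by predict.py, the conversion step would then be invoked roughly as: python convert_to_logical_forms.py /decaNLP/.data wikisql.txt logical_forms.jsonl validation --ids wikisql.ids.txt (exact paths depend on --path and --evaluate).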