adding script for external wikisql evaluation
This commit is contained in:
parent
6a5704e1a4
commit
96413a2a28
|
@ -0,0 +1,81 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
from text.torchtext.datasets.generic import Query
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import ujson as json
|
||||||
|
from metrics import to_lf
|
||||||
|
|
||||||
|
|
||||||
|
def correct_format(x):
|
||||||
|
if len(x.keys()) == 0:
|
||||||
|
x = {'query': None, 'error': 'Invalid'}
|
||||||
|
else:
|
||||||
|
c = x['conds']
|
||||||
|
proper = True
|
||||||
|
for cc in c:
|
||||||
|
if len(cc) < 3:
|
||||||
|
proper = False
|
||||||
|
if proper:
|
||||||
|
x = {'query': x, 'error': ''}
|
||||||
|
else:
|
||||||
|
x = {'query': None, 'error': 'Invalid'}
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def write_logical_forms(greedy, args):
|
||||||
|
data_dir = os.path.join(args.data, 'wikisql', 'data')
|
||||||
|
path = os.path.join(data_dir, 'dev.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.jsonl')
|
||||||
|
table_path = os.path.join(data_dir, 'dev.tables.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.tables.jsonl')
|
||||||
|
with open(table_path) as tables_file:
|
||||||
|
tables = [json.loads(line) for line in tables_file]
|
||||||
|
id_to_tables = {x['id']: x for x in tables}
|
||||||
|
|
||||||
|
examples = []
|
||||||
|
with open(path) as example_file:
|
||||||
|
for line in example_file:
|
||||||
|
entry = json.loads(line)
|
||||||
|
table = id_to_tables[entry['table_id']]
|
||||||
|
sql = entry['sql']
|
||||||
|
header = table['header']
|
||||||
|
a = repr(Query.from_dict(entry['sql'], table['header']))
|
||||||
|
ex = {'sql': sql, 'header': header, 'answer': a, 'table': table}
|
||||||
|
examples.append(ex)
|
||||||
|
|
||||||
|
with open(args.output, 'a') as f:
|
||||||
|
count = 0
|
||||||
|
correct = 0
|
||||||
|
text_answers = []
|
||||||
|
for idx, (g, ex) in enumerate(zip(greedy, examples)):
|
||||||
|
count += 1
|
||||||
|
text_answers.append([ex['answer'].lower()])
|
||||||
|
try:
|
||||||
|
lf = to_lf(g, ex['table'])
|
||||||
|
f.write(json.dumps(correct_format(lf)) + '\n')
|
||||||
|
gt = ex['sql']
|
||||||
|
conds = gt['conds']
|
||||||
|
lower_conds = []
|
||||||
|
for c in conds:
|
||||||
|
lc = c
|
||||||
|
lc[2] = str(lc[2]).lower()
|
||||||
|
lower_conds.append(lc)
|
||||||
|
gt['conds'] = lower_conds
|
||||||
|
correct += lf == gt
|
||||||
|
except Exception as e:
|
||||||
|
f.write(json.dumps(correct_format({})) + '\n')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument('data', help='path to the directory containing data for WikiSQL')
|
||||||
|
parser.add_argument('predictions', help='path to prediction file, containing one prediction per line')
|
||||||
|
parser.add_argument('output', help='path for logical forms output line by line')
|
||||||
|
parser.add_argument('evaluate' help='running on the \'validation\' or \'test\' set')
|
||||||
|
parser.add_argument('--ids', help='path to file for indices, a list of integers indicating the index into the dev/test set of the predictions on the corresponding line in \'predicitons\'')
|
||||||
|
args = parser.parse_args()
|
||||||
|
with open(args.predictions) as f:
|
||||||
|
greedy = [l for l in f]
|
||||||
|
if args.ids is not None:
|
||||||
|
with open(args.ids) as f:
|
||||||
|
ids = [int(l.strip()) for l in f]
|
||||||
|
greedy = [x[1] for x in sorted([(i, g) for i, g in zip(ids, greedy)])]
|
||||||
|
write_logical_forms(greedy, args)
|
19
predict.py
19
predict.py
|
@ -85,6 +85,8 @@ def run(args, field, val_sets, model):
|
||||||
prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.txt')
|
prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.txt')
|
||||||
answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.gold.txt')
|
answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.gold.txt')
|
||||||
results_file_name = answer_file_name.replace('gold', 'results')
|
results_file_name = answer_file_name.replace('gold', 'results')
|
||||||
|
if 'sql' in task:
|
||||||
|
ids_file_name = answer_file_name.replace('gold', 'ids')
|
||||||
if os.path.exists(prediction_file_name):
|
if os.path.exists(prediction_file_name):
|
||||||
print('** ', prediction_file_name, ' already exists**')
|
print('** ', prediction_file_name, ' already exists**')
|
||||||
if os.path.exists(answer_file_name):
|
if os.path.exists(answer_file_name):
|
||||||
|
@ -94,7 +96,8 @@ def run(args, field, val_sets, model):
|
||||||
with open(results_file_name) as results_file:
|
with open(results_file_name) as results_file:
|
||||||
for l in results_file:
|
for l in results_file:
|
||||||
print(l)
|
print(l)
|
||||||
continue
|
if not 'schema' in task:
|
||||||
|
continue
|
||||||
for x in [prediction_file_name, answer_file_name, results_file_name]:
|
for x in [prediction_file_name, answer_file_name, results_file_name]:
|
||||||
os.makedirs(os.path.dirname(x), exist_ok=True)
|
os.makedirs(os.path.dirname(x), exist_ok=True)
|
||||||
|
|
||||||
|
@ -104,13 +107,21 @@ def run(args, field, val_sets, model):
|
||||||
for batch_idx, batch in enumerate(it):
|
for batch_idx, batch in enumerate(it):
|
||||||
_, p = model(batch)
|
_, p = model(batch)
|
||||||
p = field.reverse(p)
|
p = field.reverse(p)
|
||||||
for pp in p:
|
for i, pp in enumerate(p):
|
||||||
|
if 'sql' in task:
|
||||||
|
wikisql_id = int(batch.wikisql_id[i])
|
||||||
|
wikisql_ids.append(wikisql_id)
|
||||||
prediction_file.write(pp + '\n')
|
prediction_file.write(pp + '\n')
|
||||||
predictions.append(pp)
|
predictions.append(pp)
|
||||||
else:
|
else:
|
||||||
with open(prediction_file_name) as prediction_file:
|
with open(prediction_file_name) as prediction_file:
|
||||||
predictions = [x.strip() for x in prediction_file.readlines()]
|
predictions = [x.strip() for x in prediction_file.readlines()]
|
||||||
|
|
||||||
|
if 'sql' in task:
|
||||||
|
with open(ids_file_name, 'w') as id_file:
|
||||||
|
for i in wikisql_ids:
|
||||||
|
id_file.write(json.dumps(i) + '\n')
|
||||||
|
|
||||||
def from_all_answers(an):
|
def from_all_answers(an):
|
||||||
return [it.dataset.all_answers[sid] for sid in an.tolist()]
|
return [it.dataset.all_answers[sid] for sid in an.tolist()]
|
||||||
|
|
||||||
|
@ -146,7 +157,7 @@ def get_args():
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument('--path', required=True)
|
parser.add_argument('--path', required=True)
|
||||||
parser.add_argument('--evaluate', type=str, required=True)
|
parser.add_argument('--evaluate', type=str, required=True)
|
||||||
parser.add_argument('--tasks', default=['wikisql', 'woz.en', 'cnn_dailymail', 'iwslt.en.de', 'zre', 'srl', 'squad', 'sst', 'multinli.in.out'], nargs='+')
|
parser.add_argument('--tasks', default=['wikisql', 'woz.en', 'cnn_dailymail', 'iwslt.en.de', 'zre', 'srl', 'squad', 'sst', 'multinli.in.out', 'schema'], nargs='+')
|
||||||
parser.add_argument('--gpus', type=int, help='gpus to use', required=True)
|
parser.add_argument('--gpus', type=int, help='gpus to use', required=True)
|
||||||
parser.add_argument('--seed', default=123, type=int, help='Random seed.')
|
parser.add_argument('--seed', default=123, type=int, help='Random seed.')
|
||||||
parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.')
|
parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.')
|
||||||
|
@ -178,7 +189,7 @@ def get_args():
|
||||||
'schema': 'em'}
|
'schema': 'em'}
|
||||||
|
|
||||||
if os.path.exists(os.path.join(args.path, 'process_0.log')):
|
if os.path.exists(os.path.join(args.path, 'process_0.log')):
|
||||||
args.best_checkpoint = get_best(args, lines)
|
args.best_checkpoint = get_best(args)
|
||||||
else:
|
else:
|
||||||
args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)
|
args.best_checkpoint = os.path.join(args.path, args.checkpoint_name)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue