113 lines
4.7 KiB
Python
113 lines
4.7 KiB
Python
#
|
|
# Copyright (c) 2018, Salesforce, Inc.
|
|
# All rights reserved.
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# * Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# * Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
from .tasks.generic_dataset import Query
|
|
from argparse import ArgumentParser
|
|
import os
|
|
import sys
|
|
import ujson as json
|
|
from .metrics import to_lf
|
|
|
|
|
|
def correct_format(x):
|
|
if len(x.keys()) == 0:
|
|
x = {'query': None, 'error': 'Invalid'}
|
|
else:
|
|
c = x['conds']
|
|
proper = True
|
|
for cc in c:
|
|
if len(cc) < 3:
|
|
proper = False
|
|
if proper:
|
|
x = {'query': x, 'error': ''}
|
|
else:
|
|
x = {'query': None, 'error': 'Invalid'}
|
|
return x
|
|
|
|
|
|
def write_logical_forms(greedy, args):
|
|
data_dir = os.path.join(args.data, 'wikisql', 'data')
|
|
path = os.path.join(data_dir, 'dev.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.jsonl')
|
|
table_path = os.path.join(data_dir, 'dev.tables.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.tables.jsonl')
|
|
with open(table_path) as tables_file:
|
|
tables = [json.loads(line) for line in tables_file]
|
|
id_to_tables = {x['id']: x for x in tables}
|
|
|
|
examples = []
|
|
with open(path) as example_file:
|
|
for line in example_file:
|
|
entry = json.loads(line)
|
|
table = id_to_tables[entry['table_id']]
|
|
sql = entry['sql']
|
|
header = table['header']
|
|
a = repr(Query.from_dict(entry['sql'], table['header']))
|
|
ex = {'sql': sql, 'header': header, 'answer': a, 'table': table}
|
|
examples.append(ex)
|
|
|
|
with open(args.output, 'a') as f:
|
|
count = 0
|
|
correct = 0
|
|
text_answers = []
|
|
for idx, (g, ex) in enumerate(zip(greedy, examples)):
|
|
count += 1
|
|
text_answers.append([ex['answer'].lower()])
|
|
try:
|
|
lf = to_lf(g, ex['table'])
|
|
f.write(json.dumps(correct_format(lf)) + '\n')
|
|
gt = ex['sql']
|
|
conds = gt['conds']
|
|
lower_conds = []
|
|
for c in conds:
|
|
lc = c
|
|
lc[2] = str(lc[2]).lower()
|
|
lower_conds.append(lc)
|
|
gt['conds'] = lower_conds
|
|
correct += lf == gt
|
|
except Exception as e:
|
|
f.write(json.dumps(correct_format({})) + '\n')
|
|
|
|
def main(argv=sys.argv):
|
|
parser = ArgumentParser(prog=argv[0])
|
|
parser.add_argument('data', help='path to the directory containing data for WikiSQL')
|
|
parser.add_argument('predictions', help='path to prediction file, containing one prediction per line')
|
|
parser.add_argument('ids', help='path to file for indices, a list of integers indicating the index into the dev/test set of the predictions on the corresponding line in \'predicitons\'')
|
|
parser.add_argument('output', help='path for logical forms output line by line')
|
|
parser.add_argument('evaluate', help='running on the \'validation\' or \'test\' set')
|
|
args = parser.parse_args(argv[1:])
|
|
with open(args.predictions) as f:
|
|
greedy = [l for l in f]
|
|
if args.ids is not None:
|
|
with open(args.ids) as f:
|
|
ids = [int(l.strip()) for l in f]
|
|
greedy = [x[1] for x in sorted([(i, g) for i, g in zip(ids, greedy)])]
|
|
write_logical_forms(greedy, args)
|
|
|
|
if __name__ == '__main__':
|
|
main()
|