422 lines
15 KiB
Python
422 lines
15 KiB
Python
from subprocess import Popen, PIPE, CalledProcessError
|
|
import json
|
|
from text.torchtext.datasets.generic import Query
|
|
import logging
|
|
import os
|
|
import re
|
|
import string
|
|
import collections
|
|
from multiprocessing import Pool, cpu_count
|
|
from contextlib import closing
|
|
|
|
from pyrouge import Rouge155
|
|
from sacrebleu import corpus_bleu
|
|
|
|
|
|
|
|
def _last_header_match(headers, text):
    """Return the index of the LAST ``(name, index)`` header whose name equals ``text``, else None."""
    match = None
    for col, idx in headers:
        if col == text:
            match = idx
    return match


def _restore_placeholder(token, headers_unsorted, conditionals):
    """Map a whole ``ColN`` / ``CondN`` placeholder token back to its header name / operator."""
    col = re.fullmatch(r'Col(\d+)', token)
    if col:
        return headers_unsorted[int(col.group(1))][0]
    cond = re.fullmatch(r'Cond(\d+)', token)
    if cond:
        return conditionals[int(cond.group(1))]
    # Ordinary value text (e.g. "Colorado") is left untouched.
    return token


def to_lf(s, table, agg_ops=None, cond_ops=None):
    """Parse a generated WikiSQL-style query string into a logical form.

    The part before the first ``'where'`` has its first token and last two
    tokens discarded (presumably ``select`` and ``from table`` -- confirm
    against the query serializer). What remains is either a column name or
    ``<agg> <column name>``. The optional text after ``'where'`` is parsed
    into ``[column_index, operator_index, value]`` triples.

    Args:
        s: generated query text.
        table: dict whose ``'header'`` entry lists the column names.
        agg_ops: aggregation operator names; defaults to ``Query.agg_ops``.
        cond_ops: condition operator names; defaults to ``Query.cond_ops``.

    Returns:
        ``{'sel': column index or None, 'conds': [[col, op, value], ...],
        'agg': aggregation index}``.

    Raises:
        KeyError, IndexError or ValueError when ``s`` does not follow the
        expected shape (computeLFEM treats any exception as a failed parse).
    """
    if agg_ops is None:
        agg_ops = Query.agg_ops
    if cond_ops is None:
        cond_ops = Query.cond_ops
    agg_to_idx = {op.lower(): i for i, op in enumerate(agg_ops)}
    conditionals = [op.lower() for op in cond_ops]
    headers_unsorted = [(name.lower(), i) for i, name in enumerate(table['header'])]
    # Longest names first so a header that is a substring of a longer header
    # is not substituted inside it below.
    headers = sorted(headers_unsorted, reverse=True, key=lambda h: len(h[0]))

    condition_s = None
    # NOTE: substring test, so a word containing "where" also splits here.
    if 'where' in s:
        s, condition_s = s.split('where', 1)

    s = ' '.join(s.split()[1:-2])
    agg = 0
    sel = _last_header_match(headers, s)
    if sel is None:
        # Not a bare column name: the first token must be an aggregation op.
        tokens = s.split()
        agg = agg_to_idx[tokens[0]]
        sel = _last_header_match(headers, ' '.join(tokens[1:]))

    full_conditions = []
    if condition_s is not None:
        # Replace whole-word header names with "ColN" (N = original column index).
        condition_s = ' ' + condition_s + ' '
        for col, idx in headers:
            condition_s = condition_s.replace(' ' + col + ' ', ' Col{} '.format(idx))
        condition_s = condition_s.strip()

        # Replace whole-token operators with "CondN".
        for op_idx, op in enumerate(conditionals):
            condition_s = ' '.join(
                'Cond{}'.format(op_idx) if t == op else t for t in condition_s.split())

        # Alternating "ColN CondM" heads (even positions) and value text (odd).
        conds = [x for x in re.split(r'(Col\d+ Cond\d+)', condition_s) if x.strip()]
        for i, chunk in enumerate(conds):
            tokens = chunk.split()
            if i % 2 == 0:
                col_num = int(tokens[0].replace('Col', ''))
                opp_num = int(tokens[1].replace('Cond', ''))
                full_conditions.append([col_num, opp_num])
            else:
                # Drop the conjunction joining this value to the next condition.
                if tokens[-1] == 'and':
                    tokens = tokens[:-1]
                value = ' '.join(
                    _restore_placeholder(t, headers_unsorted, conditionals) for t in tokens)
                full_conditions[-1].append(value)

    return {'sel': sel, 'conds': full_conditions, 'agg': agg}
|
|
|
|
|
|
def computeLFEM(greedy, answer, args):
    """Logical-form exact match for WikiSQL-style predictions.

    Args:
        greedy: generated query strings.
        answer: per example, a list whose first element is a dict with
            ``'answer'``, ``'table'`` and ``'sql'`` entries.
        args: unused; kept for interface compatibility.

    Returns:
        ``(score, text_answers)``. ``score`` is the percentage of examples
        whose parsed form equals the gold ``sql`` (condition values compared
        case-insensitively). ``text_answers`` is ``[[answer.lower()], ...]``.
    """
    examples = [x[0] for x in answer]
    count = 0
    correct = 0
    text_answers = []
    for g, ex in zip(greedy, examples):
        count += 1
        text_answers.append([ex['answer'].lower()])
        try:
            lf = to_lf(g, ex['table'])
            gold = ex['sql']
            # Lower-case condition values on copies so the caller's data is not mutated.
            lower_conds = []
            for c in gold['conds']:
                lc = list(c)
                lc[2] = str(lc[2]).lower()
                lower_conds.append(lc)
            gold = dict(gold, conds=lower_conds)
            correct += lf == gold
        except Exception:
            # Best effort: an unparseable prediction (or malformed gold) counts as incorrect.
            continue
    return correct / count * 100, text_answers
|
|
|
|
|
|
def computeCF1(greedy, answer):
    """Corpus-level F1 over bag-of-words matches, treating 'unanswerable' as null.

    Each text has the word "and" removed and is split on whitespace. A text
    that is exactly ``unanswerable`` is null. A true positive is a non-null
    prediction whose token multiset equals some non-null reference of the
    same example.

    Args:
        greedy: predicted strings.
        answer: per example, a list of reference strings.

    Returns:
        ``(f1, precision, recall)`` as percentages. Any component whose
        denominator is zero is reported as 0.0.
    """
    def to_counter(text):
        """Token multiset of ``text``, or None when it is just 'unanswerable'."""
        tokens = re.sub(r'\b(and)\b', ' ', text).split()
        if tokens == ['unanswerable']:
            return None
        return collections.Counter(tokens)

    greedy_counters = [to_counter(g) for g in greedy]
    answer_counters = [[to_counter(a) for a in aa] for aa in answer]

    num_not_null_greedy = sum(c is not None for c in greedy_counters)
    # An example has an answer if at least one of its references is non-null.
    num_not_null_answer = sum(any(c is not None for c in refs) for refs in answer_counters)

    num_true_positive = 0
    for g, refs in zip(greedy_counters, answer_counters):
        if g is not None and any(a is not None and a == g for a in refs):
            num_true_positive += 1

    precision = num_true_positive / num_not_null_greedy if num_not_null_greedy else 0.0
    recall = num_true_positive / num_not_null_answer if num_not_null_answer else 0.0
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall else 0.0
    return f1 * 100, precision * 100, recall * 100
|
|
|
|
def normalize_text(s):
    """Canonicalize a string for answer comparison.

    Steps, in order: lower-case, delete every character in
    ``string.punctuation``, replace the whole words a/an/the with a space,
    then collapse whitespace runs into single spaces.
    """
    text = s.lower()
    text = text.translate(str.maketrans('', '', string.punctuation))
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
|
|
|
|
def f1_score(prediction, ground_truth):
    """Token-overlap F1 between two whitespace-tokenized strings.

    Returns 0 when the two strings share no tokens (counted with
    multiplicity), otherwise the harmonic mean of precision and recall.
    """
    pred_tokens = prediction.split()
    gold_tokens = ground_truth.split()
    shared = collections.Counter(pred_tokens) & collections.Counter(gold_tokens)
    overlap = sum(shared.values())
    if not overlap:
        return 0
    precision = float(overlap) / len(pred_tokens)
    recall = float(overlap) / len(gold_tokens)
    return (2 * precision * recall) / (precision + recall)
|
|
|
|
def exact_match(prediction, ground_truth):
    """Return whether ``prediction`` compares equal to ``ground_truth``."""
    is_same = prediction == ground_truth
    return is_same
|
|
|
|
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best score of ``metric_fn(prediction, gt)`` over all references ``gt``.

    Raises ValueError (from ``max``) when ``ground_truths`` is empty.
    """
    return max(metric_fn(prediction, gt) for gt in ground_truths)
|
|
|
|
def computeF1(outputs, targets):
    """Mean (as a percentage) of the best per-example token F1 against its references."""
    total = 0
    for prediction, references in zip(outputs, targets):
        total += metric_max_over_ground_truths(f1_score, prediction, references)
    return total / len(outputs) * 100
|
|
|
|
def computeEM(outputs, targets):
    """Percentage of outputs that exactly equal at least one of their references."""
    hits = 0
    for prediction, references in zip(outputs, targets):
        hits += metric_max_over_ground_truths(exact_match, prediction, references)
    return hits / len(outputs) * 100
|
|
|
|
def computeBLEU(outputs, targets):
    """Case-insensitive corpus BLEU via sacrebleu.

    ``targets`` holds one list of references per example. It is transposed
    into one reference stream per reference position, using the number of
    references of the first example.
    """
    num_refs = len(targets[0])
    reference_streams = []
    for ref_pos in range(num_refs):
        reference_streams.append([refs[ref_pos] for refs in targets])
    return corpus_bleu(outputs, reference_streams, lowercase=True).score
|
|
|
|
|
|
class Rouge(Rouge155):
    """Rouge calculator class with custom command-line options.

    ``options`` is the list of ROUGE-1.5.5 command-line flags passed to the
    perl script (non-string entries are converted with ``str``). The class
    keeps its own copy, so neither ``DEFAULT_OPTIONS`` nor a caller-supplied
    list is mutated.
    """

    # See full list of options here:
    # https://github.com/andersjo/pyrouge/blob/master/tools/ROUGE-1.5.5/README.txt#L82
    DEFAULT_OPTIONS = [
        '-a',  # evaluate all systems
        '-n', 4,  # max-ngram
        '-x',  # do not calculate ROUGE-L
        '-2', 4,  # max-gap-length
        '-u',  # include unigram in skip-bigram
        '-c', 95,  # confidence interval
        '-r', 1000,  # number-of-samples (for resampling)
        '-f', 'A',  # scoring formula (A = model average, per the ROUGE README)
        '-p', 0.5,  # 0 <= alpha <=1
        '-t', 0,  # counting unit: 0 = sentence, 1 = token
        '-d',  # print per evaluation scores
    ]

    def __init__(self, n_words=None,
                 keep_files=False, options=None):
        """Build the option list and initialise the pyrouge base class.

        Args:
            n_words: if truthy, appended as ``-l n_words`` and forwarded to
                the base class.
            keep_files: forwarded to the base class.
            options: ROUGE flags; defaults to ``DEFAULT_OPTIONS``.
        """
        # Copy so later extension never touches DEFAULT_OPTIONS or the caller's list.
        source = self.DEFAULT_OPTIONS if options is None else options
        self.options = list(source)

        if n_words:
            self.options.extend(["-l", n_words])

        # Stemming is enabled exactly when "-m" is among the flags.
        stem = "-m" in self.options

        super(Rouge, self).__init__(
            n_words=n_words, stem=stem,
            keep_files=keep_files)

    def _run_rouge(self):
        """Run the ROUGE perl script with ``self.options`` and return its raw stdout bytes.

        A non-zero exit status is tolerated when stderr starts with
        "Illegal division by zero"; otherwise ValueError is raised carrying
        stdout and stderr.
        """
        options = (
            ['-e', self._rouge_data] +
            [str(opt) for opt in self.options] +
            [os.path.join(self._config_dir, "settings.xml")])

        logging.info("Running ROUGE with options %s", " ".join(options))
        with Popen([self._rouge_bin] + options, stdout=PIPE, stderr=PIPE) as pipes:
            std_out, std_err = pipes.communicate()

        div_by_zero_error = std_err.decode("utf-8").startswith(
            "Illegal division by zero")
        if pipes.returncode == 0 or div_by_zero_error:
            # Still returns the correct output even with div by zero
            return std_out
        raise ValueError(
            std_out.decode("utf-8") + "\n" + std_err.decode("utf-8"))
|
|
|
|
|
|
def computeROUGE(greedy, answer):
    """Average each ROUGE score over all examples, as percentages.

    The keys are taken from the first example's result; a key missing from a
    later result contributes 0.0. Returns None when there are no results.
    """
    rouges = compute_rouge_scores(greedy, answer)
    if not rouges:
        return None
    n_results = len(rouges)
    return {
        key: sum([r.get(key, 0.0) for r in rouges]) / n_results * 100
        for key in rouges[0]
    }
|
|
|
|
|
|
def split_sentences(txt, splitchar=".", include_splitchar=False):
    """Split ``txt`` on ``splitchar`` and whitespace-tokenize each piece.

    Empty pieces are dropped. A piece that contains only whitespace yields
    an empty token list. ``include_splitchar`` is accepted but currently
    ignored.
    """
    sentences = []
    for piece in txt.strip().split(splitchar):
        if piece:
            sentences.append(piece.split())
    return sentences
|
|
|
|
def compute_rouge_scores(summs, refs, splitchar='.', options=None, parallel=True):
    """Score each summary against its references with ROUGE.

    Args:
        summs: one summary string per example.
        refs: per example, a list of reference strings. These are keyed
            "A", "B", ... in the order given.
        splitchar: sentence delimiter used to split summaries and references.
        options: ROUGE command-line flags. When None, the default is
            ``-a -c 95 -m -n 2 -w 1.3``.
        parallel: score examples with a multiprocessing pool.

    Returns:
        One ``score_summary`` result per example, in input order.
    """
    assert len(summs) == len(refs)
    if options is None:
        options = [
            '-a',  # evaluate all systems
            '-c', 95,  # confidence interval
            '-m',  # use Porter stemmer
            '-n', 2,  # max-ngram
            '-w', 1.3,  # weight (weighting factor for WLCS)
        ]
    rr = Rouge(options=options)

    rouge_args = []
    for summ, ref in zip(summs, refs):
        ref_dict = {
            chr(ord('A') + i): [x for x in split_sentences(r, splitchar) if len(x) > 0]
            for i, r in enumerate(ref)
        }
        s = [x for x in split_sentences(summ, splitchar) if len(x) > 0]
        rouge_args.append((s, ref_dict))

    if parallel:
        # At least one worker: cpu_count() // 2 is 0 on single-CPU machines.
        with closing(Pool(max(1, cpu_count() // 2))) as pool:
            rouge_scores = pool.starmap(rr.score_summary, rouge_args)
    else:
        # Each summary is scored against its own references.
        rouge_scores = [rr.score_summary(s, ref_dict) for s, ref_dict in rouge_args]
    return rouge_scores
|
|
|
|
|
|
def to_delta_state(line):
    """Parse a dialogue-state delta of the form ``"slot: value, ...; req, ..."``.

    The text before the single ``';'`` holds comma-separated
    ``slot: value`` pairs. Items that are not exactly one ``slot: value``
    pair are skipped. The text after it holds comma-separated request names.

    Returns:
        ``{'inform': {slot: value}, 'request': [names]}``. The value is
        ``{'inform': {}, 'request': {}}`` for ``'None'``, blank, non-string
        input, or input without exactly one ``';'``.
    """
    delta_state = {'inform': {}, 'request': {}}
    try:
        if line == 'None' or line.strip() == '':
            return delta_state
        inform, request = [[y.strip() for y in x.strip().split(',')] for x in line.split(';')]
    except (AttributeError, TypeError, ValueError):
        # Not a usable string, or not exactly one ';' separator.
        return delta_state

    inform_pairs = {}
    for item in inform:
        parts = item.split(':')
        if len(parts) != 2:
            continue  # not a single "slot: value" pair
        k, v = parts
        inform_pairs[k.strip()] = v.strip()
    return {'inform': inform_pairs, 'request': request}
|
|
|
|
|
|
def update_state(state, delta):
    """Copy every key/value of ``delta`` into ``state`` (in place) and return ``state``."""
    for key, value in delta.items():
        state[key] = value
    return state
|
|
|
|
|
|
def dict_cmp(d1, d2):
    """Return True when ``d1`` and ``d2`` have the same keys mapped to equal values."""
    def covers(src, dst):
        # Every key of src must exist in dst with a value that is not != to src's.
        return not any(key not in dst or value != dst[key] for key, value in src.items())

    return covers(d1, d2) and covers(d2, d1)
|
|
|
|
|
|
def computeDialogue(greedy, answer):
    """Dialogue state-tracking metrics.

    Each ``answer[i][0]`` is a pair ``(turn_id, gold_delta_string)``.
    Turns are processed in sorted order. The accumulated inform state is
    reset whenever the turn id's prefix before its last ``'_'`` changes,
    i.e. a new dialogue starts.

    Returns:
        ``(joint_goal_em, turn_request_em, turn_goal_em, gold_answers)``.
        The first three are percentages. ``gold_answers`` holds
        ``[[gold_delta_string]]`` in the original example order.
    """
    # (turn id, predicted delta, gold delta, original position); the full
    # tuple is the sort key.
    records = sorted(
        (gold[0][0], pred, gold[0][1], position)
        for position, (pred, gold) in enumerate(zip(greedy, answer))
    )

    joint_hits = 0
    request_hits = 0
    goal_hits = 0
    previous_turn = None
    for turn_id, pred_line, gold_line, _ in records:
        new_dialogue = (previous_turn is None
                        or previous_turn.split('_')[:-1] != turn_id.split('_')[:-1])
        if new_dialogue:
            predicted_state, gold_state = {}, {}
        previous_turn = turn_id

        pred_delta = to_delta_state(pred_line)
        gold_delta = to_delta_state(gold_line)
        predicted_state = update_state(predicted_state, pred_delta['inform'])
        gold_state = update_state(gold_state, gold_delta['inform'])

        if dict_cmp(predicted_state, gold_state):
            joint_hits += 1
        if pred_delta['request'] == gold_delta['request']:
            request_hits += 1
        if dict_cmp(pred_delta['inform'], gold_delta['inform']):
            goal_hits += 1

    total = len(records)
    joint_goal_em = joint_hits / total * 100
    turn_request_em = request_hits / total * 100
    turn_goal_em = goal_hits / total * 100

    by_position = sorted((rec[-1], rec[-2]) for rec in records)
    gold_answers = [[gold] for _, gold in by_position]
    return joint_goal_em, turn_request_em, turn_goal_em, gold_answers
|
|
|
|
|
|
def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, logical_form=False, args=None, dialogue=False):
    """Compute the requested evaluation metrics for a batch of predictions.

    Args:
        greedy: predicted strings.
        answer: gold answers, either one value per example or a list of
            references per example (a non-list first element triggers wrapping).
        rouge, bleu, corpus_f1, logical_form, dialogue: enable the
            corresponding extra metrics.
        args: forwarded to computeLFEM.

    Returns:
        ``(OrderedDict name -> value, answers)``. The answers are the ones
        actually scored; the logical-form and dialogue evaluators replace
        them with their own gold texts.
    """
    scores = []  # (metric name, value) pairs in report order

    if not isinstance(answer[0], list):
        answer = [[a] for a in answer]

    if logical_form:
        lfem, answer = computeLFEM(greedy, answer, args)
        scores.append(('lfem', lfem))

    if dialogue:
        joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer)
        scores += [
            ('joint_goal_em', joint_goal_em),
            ('turn_request_em', request_em),
            ('turn_goal_em', turn_goal_em),
            ('avg_dialogue', (joint_goal_em + request_em) / 2),
        ]

    scores.append(('em', computeEM(greedy, answer)))

    if bleu:
        scores.append(('bleu', computeBLEU(greedy, answer)))

    if rouge:
        rouge_scores = computeROUGE(greedy, answer)
        rouge_1 = rouge_scores['rouge_1_f_score']
        rouge_2 = rouge_scores['rouge_2_f_score']
        rouge_l = rouge_scores['rouge_l_f_score']
        scores += [
            ('rouge1', rouge_1),
            ('rouge2', rouge_2),
            ('rougeL', rouge_l),
            ('avg_rouge', (rouge_1 + rouge_2 + rouge_l) / 3),
        ]

    norm_greedy = [normalize_text(g) for g in greedy]
    norm_answer = [[normalize_text(a) for a in refs] for refs in answer]
    scores.append(('nf1', computeF1(norm_greedy, norm_answer)))
    scores.append(('nem', computeEM(norm_greedy, norm_answer)))

    if corpus_f1:
        cf1, precision, recall = computeCF1(norm_greedy, norm_answer)
        scores += [('corpus_f1', cf1), ('precision', precision), ('recall', recall)]

    return collections.OrderedDict(scores), answer
|