adding device accuracy

mehrad 2018-11-29 10:17:25 -08:00
parent cd4782574d
commit 787b0887dc
4 changed files with 33 additions and 8 deletions

@@ -190,6 +190,14 @@ def fm_score(prediction, ground_truth):
         return 1.0
     return len(common) / len(ground_truth)
 
+def dm_score(prediction, ground_truth):
+    pred_devices = get_devices(prediction)
+    ground_truth = get_devices(ground_truth)
+    common = collections.Counter(pred_devices) & collections.Counter(ground_truth)
+    if not len(ground_truth):
+        return 1.0
+    return len(common) / len(ground_truth)
+
 def exact_match(prediction, ground_truth):
     return prediction == ground_truth
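
For context, a minimal sketch of what the new device-match metric computes, assuming the get_devices helper added later in this commit; the ThingTalk-style program strings are made up for illustration:

import collections

def get_devices(program):
    return [x.rsplit('.', 1)[0] for x in program.split(' ') if x.startswith('@')]

def dm_score(prediction, ground_truth):
    pred_devices = get_devices(prediction)
    gold_devices = get_devices(ground_truth)
    # Counter intersection keeps per-device minimum counts;
    # len() counts distinct matched devices
    common = collections.Counter(pred_devices) & collections.Counter(gold_devices)
    if not len(gold_devices):
        return 1.0
    return len(common) / len(gold_devices)

# Wrong function on the right device still counts as a device match:
print(dm_score('now => @com.twitter.search', 'now => @com.twitter.post'))  # 1.0
# Wrong device scores zero:
print(dm_score('now => @com.facebook.post', 'now => @com.twitter.post'))   # 0.0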
@@ -212,6 +220,11 @@ def computeFM(outputs, targets):
     outs = [metric_max_over_ground_truths(fm_score, o, t) for o, t in zip(outputs, targets)]
     return sum(outs) / len(outputs) * 100
 
+def computeDM(outputs, targets):
+    outs = [metric_max_over_ground_truths(dm_score, o, t) for o, t in zip(outputs, targets)]
+    return sum(outs) / len(outputs) * 100
+
 def computeBLEU(outputs, targets):
     targets = [[t[i] for t in targets] for i in range(len(targets[0]))]
     return corpus_bleu(outputs, targets, lowercase=True).score
@@ -392,7 +405,7 @@ def computeDialogue(greedy, answer):
     return joint_goal_em, turn_request_em, turn_goal_em, answer
 
-def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False, func_accuracy=False, args=None):
+def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False, func_accuracy=False, dev_accuracy=False, args=None):
     metric_keys = []
     metric_values = []
     if not isinstance(answer[0], list):
@@ -425,9 +438,14 @@ def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, lo
         metric_keys.extend(['nf1', 'nem'])
         metric_values.extend([nf1, nem])
     if func_accuracy:
-        func_accuracy = computeFM(greedy, answer)
+        function_accuracy = computeFM(greedy, answer)
         metric_keys.append('fm')
-        metric_values.append(func_accuracy)
+        metric_values.append(function_accuracy)
+    if dev_accuracy:
+        device_accuracy = computeDM(greedy, answer)
+        metric_keys.append('dm')
+        metric_values.append(device_accuracy)
     if corpus_f1:
         corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
         metric_keys += ['corpus_f1', 'precision', 'recall']
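
Continuing the sketch above, the corpus-level number reported under 'dm' is the average per-example device match, scoring each prediction against its best-matching reference. The metric_max_over_ground_truths below mirrors the decaNLP helper of the same name, and the example data is invented:

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # best score over all acceptable references for one example
    return max(metric_fn(prediction, gt) for gt in ground_truths)

def computeDM(outputs, targets):
    outs = [metric_max_over_ground_truths(dm_score, o, t) for o, t in zip(outputs, targets)]
    return sum(outs) / len(outputs) * 100

outputs = ['now => @com.twitter.post', 'now => @light.bulb.set_power']
targets = [['now => @com.twitter.post_picture'],
           ['now => @thermostat.set_target_temperature']]
print(computeDM(outputs, targets))  # (1.0 + 0.0) / 2 * 100 = 50.0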

@@ -184,7 +184,9 @@ def run(args, field, val_sets, model):
                 bleu='iwslt' in task or 'multi30k' in task or 'almond' in task,
                 dialogue='woz' in task,
                 rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,
-                func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
+                func_accuracy='almond' in task and not args.reverse_task_bool,
+                dev_accuracy='almond' in task and not args.reverse_task_bool,
+                args=args)
             with open(results_file_name, 'w') as results_file:
                 results_file.write(json.dumps(metrics) + '\n')
         else:
@@ -240,7 +242,7 @@ def get_args():
         'multinli.in.out': 'em',
         'squad': 'nf1',
         'srl': 'nf1',
-        'almond': 'bleu',
+        'almond': 'bleu' if args.reverse_task_bool else 'em',
         'sst': 'em',
         'wikisql': 'lfem',
         'woz.en': 'joint_goal_em',
@@ -264,7 +266,7 @@ def get_best(args):
         lines = f.readlines()
     best_score = 0
-    best_it = 0
+    best_it = 10
     deca_scores = {}
     for l in lines:
         if 'val' in l:

@@ -2,4 +2,7 @@ import os
 import numpy
 
 def get_functions(program):
-    return [x for x in program.split(' ') if x.startswith('@')]
+    return [x for x in program.split(' ') if x.startswith('@')]
+
+def get_devices(program):
+    return [x.rsplit('.', 1)[0] for x in program.split(' ') if x.startswith('@')]
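
As a quick illustration of the two helpers, using a made-up ThingTalk-like program string: get_functions keeps whole '@'-prefixed function tokens, while get_devices drops the trailing function name via rsplit:

program = 'now => @com.twitter.post param:status = QUOTED_STRING_0'
print(get_functions(program))  # ['@com.twitter.post']
print(get_devices(program))    # ['@com.twitter']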

@@ -89,7 +89,9 @@ def validate(task, val_iter, model, logger, field, world_size, rank, iteration,
         bleu='iwslt' in task or 'multi30k' in task or 'almond' in task,
         dialogue='woz' in task,
         rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,
-        func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
+        func_accuracy='almond' in task and not args.reverse_task_bool,
+        dev_accuracy='almond' in task and not args.reverse_task_bool,
+        args=args)
     results = [predictions, answers] + results
     print_results(names, results, rank=rank, num_print=num_print)