adding device accuracy
This commit is contained in:
parent
cd4782574d
commit
787b0887dc
24
metrics.py
24
metrics.py
|
@ -190,6 +190,14 @@ def fm_score(prediction, ground_truth):
|
|||
return 1.0
|
||||
return len(common) / len(ground_truth)
|
||||
|
||||
def dm_score(prediction, ground_truth):
    """Device-match score between a predicted and a gold program.

    Extracts the '@'-prefixed device names from both programs and returns
    the size of their multiset intersection divided by the number of gold
    devices.  Returns 1.0 when the ground truth contains no devices
    (nothing to match), mirroring fm_score's convention.
    """
    pred_devices = get_devices(prediction)
    gold_devices = get_devices(ground_truth)
    # Guard first: no gold devices means a vacuous perfect match, and it
    # lets us skip building the Counters entirely.
    if not gold_devices:
        return 1.0
    # Multiset intersection so repeated devices must be predicted the
    # right number of times.
    common = collections.Counter(pred_devices) & collections.Counter(gold_devices)
    # NOTE(review): len(common) counts *distinct* matched devices, not
    # total matched occurrences (sum(common.values())).  This preserves
    # the original behavior — confirm it is intentional.
    return len(common) / len(gold_devices)
|
||||
|
||||
def exact_match(prediction, ground_truth):
    """Return True iff the prediction is identical to the ground truth."""
    is_identical = prediction == ground_truth
    return is_identical
|
||||
|
||||
|
@ -212,6 +220,11 @@ def computeFM(outputs, targets):
|
|||
outs = [metric_max_over_ground_truths(fm_score, o, t) for o, t in zip(outputs, targets)]
|
||||
return sum(outs) / len(outputs) * 100
|
||||
|
||||
def computeDM(outputs, targets):
    """Corpus-level device-match accuracy as a percentage.

    For each output, takes the maximum dm_score over its reference set,
    then averages over the corpus and scales to [0, 100].
    """
    # Guard the empty corpus instead of raising ZeroDivisionError below.
    if not outputs:
        return 0.0
    outs = [metric_max_over_ground_truths(dm_score, o, t) for o, t in zip(outputs, targets)]
    return sum(outs) / len(outputs) * 100
|
||||
|
||||
|
||||
def computeBLEU(outputs, targets):
    """Corpus BLEU (lowercased) of outputs against multi-reference targets."""
    # Transpose per-example reference lists into per-reference streams,
    # the layout corpus_bleu expects.
    reference_streams = [list(refs) for refs in zip(*targets)]
    return corpus_bleu(outputs, reference_streams, lowercase=True).score
|
||||
|
@ -392,7 +405,7 @@ def computeDialogue(greedy, answer):
|
|||
return joint_goal_em, turn_request_em, turn_goal_em, answer
|
||||
|
||||
|
||||
def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False, func_accuracy=False, args=None):
|
||||
def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False, func_accuracy=False, dev_accuracy=False, args=None):
|
||||
metric_keys = []
|
||||
metric_values = []
|
||||
if not isinstance(answer[0], list):
|
||||
|
@ -425,9 +438,14 @@ def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, lo
|
|||
metric_keys.extend(['nf1', 'nem'])
|
||||
metric_values.extend([nf1, nem])
|
||||
if func_accuracy:
|
||||
func_accuracy = computeFM(greedy, answer)
|
||||
function_accuracy = computeFM(greedy, answer)
|
||||
metric_keys.append('fm')
|
||||
metric_values.append(func_accuracy)
|
||||
metric_values.append(function_accuracy)
|
||||
if dev_accuracy:
|
||||
device_accuracy = computeDM(greedy, answer)
|
||||
metric_keys.append('dm')
|
||||
metric_values.append(device_accuracy)
|
||||
|
||||
if corpus_f1:
|
||||
corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
|
||||
metric_keys += ['corpus_f1', 'precision', 'recall']
|
||||
|
|
|
@ -184,7 +184,9 @@ def run(args, field, val_sets, model):
|
|||
bleu='iwslt' in task or 'multi30k' in task or 'almond' in task,
|
||||
dialogue='woz' in task,
|
||||
rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,
|
||||
func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
|
||||
func_accuracy='almond' in task and not args.reverse_task_bool,
|
||||
dev_accuracy='almond' in task and not args.reverse_task_bool,
|
||||
args=args)
|
||||
with open(results_file_name, 'w') as results_file:
|
||||
results_file.write(json.dumps(metrics) + '\n')
|
||||
else:
|
||||
|
@ -240,7 +242,7 @@ def get_args():
|
|||
'multinli.in.out': 'em',
|
||||
'squad': 'nf1',
|
||||
'srl': 'nf1',
|
||||
'almond': 'bleu',
|
||||
'almond': 'bleu' if args.reverse_task_bool else 'em',
|
||||
'sst': 'em',
|
||||
'wikisql': 'lfem',
|
||||
'woz.en': 'joint_goal_em',
|
||||
|
@ -264,7 +266,7 @@ def get_best(args):
|
|||
lines = f.readlines()
|
||||
|
||||
best_score = 0
|
||||
best_it = 0
|
||||
best_it = 10
|
||||
deca_scores = {}
|
||||
for l in lines:
|
||||
if 'val' in l:
|
||||
|
|
|
@ -2,4 +2,7 @@ import os
|
|||
import numpy
|
||||
|
||||
def get_functions(program):
    """Return every '@'-prefixed (function) token of the program string."""
    tokens = program.split(' ')
    return [tok for tok in tokens if tok.startswith('@')]
|
||||
|
||||
def get_devices(program):
    """Return the device part (text before the last '.') of each
    '@'-prefixed token in the program string."""
    devices = []
    for tok in program.split(' '):
        if tok.startswith('@'):
            devices.append(tok.rsplit('.', 1)[0])
    return devices
|
|
@ -89,7 +89,9 @@ def validate(task, val_iter, model, logger, field, world_size, rank, iteration,
|
|||
bleu='iwslt' in task or 'multi30k' in task or 'almond' in task,
|
||||
dialogue='woz' in task,
|
||||
rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,
|
||||
func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
|
||||
func_accuracy='almond' in task and not args.reverse_task_bool,
|
||||
dev_accuracy='almond' in task and not args.reverse_task_bool,
|
||||
args=args)
|
||||
results = [predictions, answers] + results
|
||||
print_results(names, results, rank=rank, num_print=num_print)
|
||||
|
||||
|
|
Loading…
Reference in New Issue