diff --git a/metrics.py b/metrics.py
index 244685a0..44bf603f 100644
--- a/metrics.py
+++ b/metrics.py
@@ -190,6 +190,14 @@ def fm_score(prediction, ground_truth):
         return 1.0
     return len(common) / len(ground_truth)
 
+def dm_score(prediction, ground_truth):
+    pred_funcs = get_devices(prediction)
+    ground_truth = get_devices(ground_truth)
+    common = collections.Counter(pred_funcs) & collections.Counter(ground_truth)
+    if not len(ground_truth):
+        return 1.0
+    return len(common) / len(ground_truth)
+
 
 def exact_match(prediction, ground_truth):
     return prediction == ground_truth
@@ -212,6 +220,11 @@ def computeFM(outputs, targets):
     outs = [metric_max_over_ground_truths(fm_score, o, t) for o, t in zip(outputs, targets)]
     return sum(outs) / len(outputs) * 100
 
+def computeDM(outputs, targets):
+    outs = [metric_max_over_ground_truths(dm_score, o, t) for o, t in zip(outputs, targets)]
+    return sum(outs) / len(outputs) * 100
+
+
 def computeBLEU(outputs, targets):
     targets = [[t[i] for t in targets] for i in range(len(targets[0]))]
     return corpus_bleu(outputs, targets, lowercase=True).score
@@ -392,7 +405,7 @@ def computeDialogue(greedy, answer):
     return joint_goal_em, turn_request_em, turn_goal_em, answer
 
 
-def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False, func_accuracy=False, args=None):
+def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, logical_form=False, dialogue=False, func_accuracy=False, dev_accuracy=False, args=None):
     metric_keys = []
     metric_values = []
     if not isinstance(answer[0], list):
@@ -425,9 +438,14 @@ def compute_metrics(greedy, answer, rouge=False, bleu=False, corpus_f1=False, lo
         metric_keys.extend(['nf1', 'nem'])
         metric_values.extend([nf1, nem])
     if func_accuracy:
-        func_accuracy = computeFM(greedy, answer)
+        function_accuracy = computeFM(greedy, answer)
         metric_keys.append('fm')
-        metric_values.append(func_accuracy)
+        metric_values.append(function_accuracy)
+    if dev_accuracy:
+        device_accuracy = computeDM(greedy, answer)
+        metric_keys.append('dm')
+        metric_values.append(device_accuracy)
+
     if corpus_f1:
         corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
         metric_keys += ['corpus_f1', 'precision', 'recall']
diff --git a/predict.py b/predict.py
index 3fc319f3..80a83485 100644
--- a/predict.py
+++ b/predict.py
@@ -184,7 +184,9 @@ def run(args, field, val_sets, model):
                 bleu='iwslt' in task or 'multi30k' in task or 'almond' in task,
                 dialogue='woz' in task, rouge='cnn' in task,
                 logical_form='sql' in task, corpus_f1='zre' in task,
-                func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
+                func_accuracy='almond' in task and not args.reverse_task_bool,
+                dev_accuracy='almond' in task and not args.reverse_task_bool,
+                args=args)
             with open(results_file_name, 'w') as results_file:
                 results_file.write(json.dumps(metrics) + '\n')
         else:
@@ -240,7 +242,7 @@ def get_args():
                       'multinli.in.out': 'em',
                       'squad': 'nf1',
                       'srl': 'nf1',
-                      'almond': 'bleu',
+                      'almond': 'bleu' if args.reverse_task_bool else 'em',
                       'sst': 'em',
                       'wikisql': 'lfem',
                       'woz.en': 'joint_goal_em',
@@ -264,7 +266,7 @@ def get_best(args):
         lines = f.readlines()
 
     best_score = 0
-    best_it = 0
+    best_it = 10
     deca_scores = {}
     for l in lines:
         if 'val' in l:
diff --git a/utils/lang_utils.py b/utils/lang_utils.py
index ae3aee56..3a482392 100644
--- a/utils/lang_utils.py
+++ b/utils/lang_utils.py
@@ -2,4 +2,7 @@ import os
 import numpy
 
 def get_functions(program):
-    return [x for x in program.split(' ') if x.startswith('@')]
\ No newline at end of file
+    return [x for x in program.split(' ') if x.startswith('@')]
+
+def get_devices(program):
+    return [x.rsplit('.', 1)[0] for x in program.split(' ') if x.startswith('@')]
\ No newline at end of file
diff --git a/validate.py b/validate.py
index c8f3a0fe..22eba715 100644
--- a/validate.py
+++ b/validate.py
@@ -89,7 +89,9 @@ def validate(task, val_iter, model, logger, field, world_size, rank, iteration,
             bleu='iwslt' in task or 'multi30k' in task or 'almond' in task,
             dialogue='woz' in task, rouge='cnn' in task,
             logical_form='sql' in task, corpus_f1='zre' in task,
-            func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
+            func_accuracy='almond' in task and not args.reverse_task_bool,
+            dev_accuracy='almond' in task and not args.reverse_task_bool,
+            args=args)
         results = [predictions, answers] + results
 
         print_results(names, results, rank=rank, num_print=num_print)
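
Sanity check (not part of the patch): the snippet below copies the new get_devices helper and the dm_score logic from this diff and runs them on made-up ThingTalk-style programs, to illustrate that the device-match metric only compares the device prefix of each @-token (e.g. '@com.twitter'), so a prediction that picks the right device but the wrong function still gets full device-match credit. The example program strings and token names are invented for illustration.

    import collections

    def get_functions(program):
        # tokens like '@com.twitter.search' name a device function
        return [x for x in program.split(' ') if x.startswith('@')]

    def get_devices(program):
        # keep only the device part, e.g. '@com.twitter.search' -> '@com.twitter'
        return [x.rsplit('.', 1)[0] for x in program.split(' ') if x.startswith('@')]

    def dm_score(prediction, ground_truth):
        # same logic as the dm_score added in metrics.py
        pred_funcs = get_devices(prediction)
        ground_truth = get_devices(ground_truth)
        common = collections.Counter(pred_funcs) & collections.Counter(ground_truth)
        if not len(ground_truth):
            return 1.0
        return len(common) / len(ground_truth)

    # hypothetical programs, made up for this example
    gold = 'now => @com.twitter.search => notify'
    pred = 'now => @com.twitter.home_timeline => notify'
    print(get_functions(pred))   # ['@com.twitter.home_timeline']
    print(dm_score(pred, gold))  # 1.0 -- right device, wrong function

In the same situation the function-match metric (fm_score/computeFM) would score 0.0, which is why 'dm' is reported alongside 'fm' for the almond task.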