Improve the naming of inputs of `compute_metrics`

Sina 2022-02-24 18:23:19 -08:00
parent 0b3465f502
commit 2330c08ca0
1 changed file with 35 additions and 32 deletions


@@ -519,9 +519,11 @@ def computeDialogue(greedy, answer):
     return joint_goal_em, turn_request_em, turn_goal_em, answer


-def compute_metrics(greedy: Iterable[str], answer: Union[Iterable[str], Iterable[Iterable[str]]], requested_metrics: Iterable, lang):
+def compute_metrics(predictions: Iterable[str], answers: Union[Iterable[str], Iterable[Iterable[str]]], requested_metrics: Iterable, lang: str):
     """
     Inputs:
+        predictions: a list of model predictions
+        answers: a list of gold answers, each answer can be one item, or a list if multiple gold answers exist
         requested_metrics: contains a subset of the following metrics
             em (exact match)
             sm (structure match): valid if the output is ThingTalk code. Whether the gold answer and prediction are identical if we ignore parameter values of ThingTalk programs
@@ -533,80 +535,81 @@ def compute_metrics(greedy: Iterable[str], answer: Union[Iterable[str], Iterable
             corpus_f1, precision, recall: corpus-level precision, recall and F1 score
             lfem
             joint_goal_em, turn_request_em, turn_goal_em, avg_dialogue
+        lang: the language of the predictions and answers. Used for BERTScore.
     """
     metric_keys = []
     metric_values = []
-    if not isinstance(answer[0], list):
-        answer = [[a] for a in answer]
+    if not isinstance(answers[0], list):
+        answers = [[a] for a in answers]
     if 'lfem' in requested_metrics:
-        lfem, answer = computeLFEM(greedy, answer)
+        lfem, answers = computeLFEM(predictions, answers)
         metric_keys += ['lfem']
         metric_values += [lfem]
     if 'joint_goal_em' in requested_metrics:
-        joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer)
+        joint_goal_em, request_em, turn_goal_em, answers = computeDialogue(predictions, answers)
         avg_dialogue = (joint_goal_em + request_em) / 2
         metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue']
         metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue]
-    em = computeEM(greedy, answer)
+    em = computeEM(predictions, answers)
     metric_keys += ['em']
     metric_values += [em]
     if 'pem' in requested_metrics:
-        pem = computePartialEM(greedy, answer)
+        pem = computePartialEM(predictions, answers)
         metric_keys.append('pem')
         metric_values.append(pem)
     if 'sm' in requested_metrics:
-        sm = computeSM(greedy, answer)
+        sm = computeSM(predictions, answers)
         metric_keys.append('sm')
         metric_values.append(sm)
     if 'ter' in requested_metrics:
-        ter = computeTER(greedy, answer)
+        ter = computeTER(predictions, answers)
         metric_keys.append('ter')
         metric_values.append(ter)
     if 'bertscore' in requested_metrics:
-        bertscore = computeBERTScore(greedy, answer, lang)
+        bertscore = computeBERTScore(predictions, answers, lang)
         metric_keys.append('bertscore')
         metric_values.append(bertscore)
     if 'casedbleu' in requested_metrics:
-        casedbleu = computeCasedBLEU(greedy, answer)
+        casedbleu = computeCasedBLEU(predictions, answers)
         metric_keys.append('casedbleu')
         metric_values.append(casedbleu)
     if 'bleu' in requested_metrics:
-        bleu = computeBLEU(greedy, answer)
+        bleu = computeBLEU(predictions, answers)
         metric_keys.append('bleu')
         metric_values.append(bleu)
     if 't5_bleu' in requested_metrics:
-        t5_bleu = computeT5BLEU(greedy, answer)
+        t5_bleu = computeT5BLEU(predictions, answers)
         metric_keys.append('t5_bleu')
         metric_values.append(t5_bleu)
     if 'nmt_bleu' in requested_metrics:
-        nmt_bleu = computeNMTBLEU(greedy, answer)
+        nmt_bleu = computeNMTBLEU(predictions, answers)
         metric_keys.append('nmt_bleu')
         metric_values.append(nmt_bleu)
     if 'avg_rouge' in requested_metrics:
-        rouge = computeROUGE(greedy, answer)
+        rouge = computeROUGE(predictions, answers)
         metric_keys += ['rouge1', 'rouge2', 'rougeL', 'avg_rouge']
         avg_rouge = (rouge['rouge_1_f_score'] + rouge['rouge_2_f_score'] + rouge['rouge_l_f_score']) / 3
         metric_values += [rouge['rouge_1_f_score'], rouge['rouge_2_f_score'], rouge['rouge_l_f_score'], avg_rouge]
     if 'sc_precision' in requested_metrics:
-        precision = computeSequenceClassificationPrecision(greedy, answer)
+        precision = computeSequenceClassificationPrecision(predictions, answers)
         metric_keys.append('sc_precision')
         metric_values.append(precision)
     if 'sc_recall' in requested_metrics:
-        recall = computeSequenceClassificationRecall(greedy, answer)
+        recall = computeSequenceClassificationRecall(predictions, answers)
         metric_keys.append('sc_recall')
         metric_values.append(recall)
     if 'sc_f1' in requested_metrics:
-        f1 = computeSequenceClassificationF1(greedy, answer)
+        f1 = computeSequenceClassificationF1(predictions, answers)
         metric_keys.append('sc_f1')
         metric_values.append(f1)
     if 'f1' in requested_metrics:
-        f1 = computeF1(greedy, answer)
+        f1 = computeF1(predictions, answers)
         metric_keys.append('f1')
         metric_values.append(f1)
     if 'ner_f1_IOB1' in requested_metrics:
-        greedy_processed = [pred.split() for pred in greedy]
-        answer_processed = [ans[0].split() for ans in answer]
+        predictions_processed = [pred.split() for pred in predictions]
+        answers_processed = [ans[0].split() for ans in answers]

         def convert_IOB2_to_IOB1(labels):
             cur_category = None
@@ -615,37 +618,37 @@ def compute_metrics(greedy: Iterable[str], answer: Union[Iterable[str], Iterable
                     labels[n] = "I" + label[1:]
                 cur_category = label[2:]

-        convert_IOB2_to_IOB1(greedy_processed)
-        convert_IOB2_to_IOB1(answer_processed)
+        convert_IOB2_to_IOB1(predictions_processed)
+        convert_IOB2_to_IOB1(answers_processed)
         f1 = (
-            seq_metrics.f1_score(y_pred=greedy_processed, y_true=answer_processed, mode='strict', scheme=seq_scheme.IOB1) * 100
+            seq_metrics.f1_score(y_pred=predictions_processed, y_true=answers_processed, mode='strict', scheme=seq_scheme.IOB1) * 100
         )
         metric_keys.append('ner_f1_IOB1')
         metric_values.append(f1)
     if 'ner_f1' in requested_metrics:
-        greedy_processed = [pred.split() for pred in greedy]
-        answer_processed = [ans[0].split() for ans in answer]
+        predictions_processed = [pred.split() for pred in predictions]
+        answers_processed = [ans[0].split() for ans in answers]
-        f1 = seq_metrics.f1_score(y_pred=greedy_processed, y_true=answer_processed) * 100
+        f1 = seq_metrics.f1_score(y_pred=predictions_processed, y_true=answers_processed) * 100
         metric_keys.append('ner_f1')
         metric_values.append(f1)
-    norm_greedy = [normalize_text(g) for g in greedy]
-    norm_answer = [[normalize_text(a) for a in al] for al in answer]
+    norm_predictions = [normalize_text(g) for g in predictions]
+    norm_answers = [[normalize_text(a) for a in al] for al in answers]
     if 'nf1' in requested_metrics:
-        nf1 = computeF1(norm_greedy, norm_answer)
+        nf1 = computeF1(norm_predictions, norm_answers)
         metric_keys.append('nf1')
         metric_values.append(nf1)
     if 'nem' in requested_metrics:
-        nem = computeEM(norm_greedy, norm_answer)
+        nem = computeEM(norm_predictions, norm_answers)
         metric_keys.append('nem')
         metric_values.append(nem)
     if 'corpus_f1' in requested_metrics:
-        corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
+        corpus_f1, precision, recall = computeCF1(norm_predictions, norm_answers)
         metric_keys += ['corpus_f1', 'precision', 'recall']
         metric_values += [corpus_f1, precision, recall]
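
For reference, a minimal sketch of how the renamed function might be called after this change. The module name `metrics`, the sample data, and the use of the return value are assumptions for illustration only; the parameter names, the accepted answer formats, and the metric names come from the diff above, while the return shape is not shown in this commit.

# Illustrative usage sketch; `metrics` as the module name is an assumption,
# since this commit only shows the function body, not the file it lives in.
from metrics import compute_metrics

predictions = ['paris', 'the cat sat on the mat']
# Each gold answer is a list of acceptable strings; a flat list of strings is
# also accepted and is wrapped into single-element lists by compute_metrics.
answers = [['Paris', 'paris'], ['the cat sat on the mat']]

results = compute_metrics(
    predictions,
    answers,
    requested_metrics=['em', 'nem', 'nf1'],  # 'em' is always computed; the others are opt-in
    lang='en',  # per the new docstring, used for BERTScore
)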