Improve the naming of inputs of `compute_metrics`
This commit is contained in:
parent 0b3465f502
commit 2330c08ca0
@@ -519,9 +519,11 @@ def computeDialogue(greedy, answer):
     return joint_goal_em, turn_request_em, turn_goal_em, answer


-def compute_metrics(greedy: Iterable[str], answer: Union[Iterable[str], Iterable[Iterable[str]]], requested_metrics: Iterable, lang):
+def compute_metrics(predictions: Iterable[str], answers: Union[Iterable[str], Iterable[Iterable[str]]], requested_metrics: Iterable, lang: str):
     """
     Inputs:
+        predictions: a list of model predictions
+        answers: a list of gold answers, each answer can be one item, or a list if multiple gold answers exist
         requested_metrics: contains a subset of the following metrics
             em (exact match)
             sm (structure match): valid if the output is ThingTalk code. Whether the gold answer and prediction are identical if we ignore parameter values of ThingTalk programs
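For reference, a minimal usage sketch of the renamed signature. The inputs below are hypothetical and only illustrate the documented argument shapes (a flat list of predictions, and per-example gold answers that may be a single string or a list of strings); the function's return handling is not part of this commit and is not shown here:

    predictions = ['turn on the lights', 'play jazz']                        # one model output per example
    answers = [['turn on the lights'], ['play jazz', 'play some jazz']]      # one or more gold answers per example
    compute_metrics(predictions, answers, requested_metrics=['em', 'nem', 'nf1'], lang='en')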
@@ -533,80 +535,81 @@ def compute_metrics(greedy: Iterable[str], answer: Union[Iterable[str], Iterable
             corpus_f1, precision, recall: corpus-level precision, recall and F1 score
             lfem
             joint_goal_em, turn_request_em, turn_goal_em, avg_dialogue
+        lang: the language of the predictions and answers. Used for BERTScore.
     """
     metric_keys = []
     metric_values = []
-    if not isinstance(answer[0], list):
-        answer = [[a] for a in answer]
+    if not isinstance(answers[0], list):
+        answers = [[a] for a in answers]
     if 'lfem' in requested_metrics:
-        lfem, answer = computeLFEM(greedy, answer)
+        lfem, answers = computeLFEM(predictions, answers)
         metric_keys += ['lfem']
         metric_values += [lfem]
     if 'joint_goal_em' in requested_metrics:
-        joint_goal_em, request_em, turn_goal_em, answer = computeDialogue(greedy, answer)
+        joint_goal_em, request_em, turn_goal_em, answers = computeDialogue(predictions, answers)
         avg_dialogue = (joint_goal_em + request_em) / 2
         metric_keys += ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue']
         metric_values += [joint_goal_em, request_em, turn_goal_em, avg_dialogue]
-    em = computeEM(greedy, answer)
+    em = computeEM(predictions, answers)
     metric_keys += ['em']
     metric_values += [em]
     if 'pem' in requested_metrics:
-        pem = computePartialEM(greedy, answer)
+        pem = computePartialEM(predictions, answers)
         metric_keys.append('pem')
         metric_values.append(pem)
     if 'sm' in requested_metrics:
-        sm = computeSM(greedy, answer)
+        sm = computeSM(predictions, answers)
         metric_keys.append('sm')
         metric_values.append(sm)
     if 'ter' in requested_metrics:
-        ter = computeTER(greedy, answer)
+        ter = computeTER(predictions, answers)
         metric_keys.append('ter')
         metric_values.append(ter)
     if 'bertscore' in requested_metrics:
-        bertscore = computeBERTScore(greedy, answer, lang)
+        bertscore = computeBERTScore(predictions, answers, lang)
         metric_keys.append('bertscore')
         metric_values.append(bertscore)
     if 'casedbleu' in requested_metrics:
-        casedbleu = computeCasedBLEU(greedy, answer)
+        casedbleu = computeCasedBLEU(predictions, answers)
         metric_keys.append('casedbleu')
         metric_values.append(casedbleu)
     if 'bleu' in requested_metrics:
-        bleu = computeBLEU(greedy, answer)
+        bleu = computeBLEU(predictions, answers)
         metric_keys.append('bleu')
         metric_values.append(bleu)
     if 't5_bleu' in requested_metrics:
-        t5_bleu = computeT5BLEU(greedy, answer)
+        t5_bleu = computeT5BLEU(predictions, answers)
         metric_keys.append('t5_bleu')
         metric_values.append(t5_bleu)
     if 'nmt_bleu' in requested_metrics:
-        nmt_bleu = computeNMTBLEU(greedy, answer)
+        nmt_bleu = computeNMTBLEU(predictions, answers)
         metric_keys.append('nmt_bleu')
         metric_values.append(nmt_bleu)
     if 'avg_rouge' in requested_metrics:
-        rouge = computeROUGE(greedy, answer)
+        rouge = computeROUGE(predictions, answers)
         metric_keys += ['rouge1', 'rouge2', 'rougeL', 'avg_rouge']
         avg_rouge = (rouge['rouge_1_f_score'] + rouge['rouge_2_f_score'] + rouge['rouge_l_f_score']) / 3
         metric_values += [rouge['rouge_1_f_score'], rouge['rouge_2_f_score'], rouge['rouge_l_f_score'], avg_rouge]
     if 'sc_precision' in requested_metrics:
-        precision = computeSequenceClassificationPrecision(greedy, answer)
+        precision = computeSequenceClassificationPrecision(predictions, answers)
         metric_keys.append('sc_precision')
         metric_values.append(precision)
     if 'sc_recall' in requested_metrics:
-        recall = computeSequenceClassificationRecall(greedy, answer)
+        recall = computeSequenceClassificationRecall(predictions, answers)
         metric_keys.append('sc_recall')
         metric_values.append(recall)
     if 'sc_f1' in requested_metrics:
-        f1 = computeSequenceClassificationF1(greedy, answer)
+        f1 = computeSequenceClassificationF1(predictions, answers)
         metric_keys.append('sc_f1')
         metric_values.append(f1)
     if 'f1' in requested_metrics:
-        f1 = computeF1(greedy, answer)
+        f1 = computeF1(predictions, answers)
         metric_keys.append('f1')
         metric_values.append(f1)

     if 'ner_f1_IOB1' in requested_metrics:
-        greedy_processed = [pred.split() for pred in greedy]
-        answer_processed = [ans[0].split() for ans in answer]
+        predictions_processed = [pred.split() for pred in predictions]
+        answers_processed = [ans[0].split() for ans in answers]

         def convert_IOB2_to_IOB1(labels):
             cur_category = None
@@ -615,37 +618,37 @@ def compute_metrics(greedy: Iterable[str], answer: Union[Iterable[str], Iterable
                     labels[n] = "I" + label[1:]
                 cur_category = label[2:]

-        convert_IOB2_to_IOB1(greedy_processed)
-        convert_IOB2_to_IOB1(answer_processed)
+        convert_IOB2_to_IOB1(predictions_processed)
+        convert_IOB2_to_IOB1(answers_processed)
         f1 = (
-            seq_metrics.f1_score(y_pred=greedy_processed, y_true=answer_processed, mode='strict', scheme=seq_scheme.IOB1) * 100
+            seq_metrics.f1_score(y_pred=predictions_processed, y_true=answers_processed, mode='strict', scheme=seq_scheme.IOB1) * 100
         )

         metric_keys.append('ner_f1_IOB1')
         metric_values.append(f1)

     if 'ner_f1' in requested_metrics:
-        greedy_processed = [pred.split() for pred in greedy]
-        answer_processed = [ans[0].split() for ans in answer]
+        predictions_processed = [pred.split() for pred in predictions]
+        answers_processed = [ans[0].split() for ans in answers]

-        f1 = seq_metrics.f1_score(y_pred=greedy_processed, y_true=answer_processed) * 100
+        f1 = seq_metrics.f1_score(y_pred=predictions_processed, y_true=answers_processed) * 100

         metric_keys.append('ner_f1')
         metric_values.append(f1)

-    norm_greedy = [normalize_text(g) for g in greedy]
-    norm_answer = [[normalize_text(a) for a in al] for al in answer]
+    norm_predictions = [normalize_text(g) for g in predictions]
+    norm_answers = [[normalize_text(a) for a in al] for al in answers]
     if 'nf1' in requested_metrics:
-        nf1 = computeF1(norm_greedy, norm_answer)
+        nf1 = computeF1(norm_predictions, norm_answers)
         metric_keys.append('nf1')
         metric_values.append(nf1)
     if 'nem' in requested_metrics:
-        nem = computeEM(norm_greedy, norm_answer)
+        nem = computeEM(norm_predictions, norm_answers)
         metric_keys.append('nem')
         metric_values.append(nem)

     if 'corpus_f1' in requested_metrics:
-        corpus_f1, precision, recall = computeCF1(norm_greedy, norm_answer)
+        corpus_f1, precision, recall = computeCF1(norm_predictions, norm_answers)
         metric_keys += ['corpus_f1', 'precision', 'recall']
         metric_values += [corpus_f1, precision, recall]
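As context for the ner_f1_IOB1 branch above, a small standalone sketch of what an IOB2-to-IOB1 conversion does; this is an illustration, not the module's convert_IOB2_to_IOB1 helper. In IOB2 every entity starts with a B- tag, while in IOB1 a B- tag is only needed when an entity directly follows another entity of the same type:

    def iob2_to_iob1(labels):
        # Rewrite B- tags to I- unless the previous token ended an entity of the same type.
        prev_type = None
        for i, tag in enumerate(labels):
            if tag.startswith('B-') and tag[2:] != prev_type:
                labels[i] = 'I' + tag[1:]
            prev_type = tag[2:] if tag != 'O' else None

    tags = ['B-PER', 'I-PER', 'O', 'B-LOC', 'B-LOC']
    iob2_to_iob1(tags)
    # tags is now ['I-PER', 'I-PER', 'O', 'I-LOC', 'B-LOC']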