metrics: first attempt to unify dialogue metrics between genienlp and dialogues
This commit is contained in:
parent
95d7d5630c
commit
77ec8526cc
|
@ -34,7 +34,7 @@ from typing import List, Union
|
|||
import sacrebleu
|
||||
from datasets import load_metric
|
||||
from dialogues import Bitod
|
||||
from dialogues.bitod.src.evaluate import convert_lists_to_set
|
||||
from dialogues.bitod.src.evaluate import compute_da, compute_dst_em, compute_ser
|
||||
from seqeval import metrics as seq_metrics
|
||||
from seqeval import scheme as seq_scheme
|
||||
|
||||
|
@ -225,23 +225,24 @@ def compute_e2e_dialogue_score(greedy, answer, tgt_lang, args, example_ids, cont
|
|||
|
||||
|
||||
def computeSER(greedy, inputs):
|
||||
ser = 0.0
|
||||
for pred, input in zip(greedy, inputs):
|
||||
act_values = QUOTED_MATCH_REGEX.findall(input)
|
||||
missing = False
|
||||
if len(act_values):
|
||||
for val in act_values:
|
||||
if val not in pred:
|
||||
missing = True
|
||||
if missing:
|
||||
ser += 1.0
|
||||
return ser / len(greedy) * 100
|
||||
act_values = []
|
||||
for input in inputs:
|
||||
act_values += QUOTED_MATCH_REGEX.findall(input)
|
||||
|
||||
return compute_ser(greedy, act_values)
|
||||
|
||||
|
||||
def computeDA(greedy, answer):
|
||||
answer = [a[0] for a in answer]
|
||||
return compute_da(greedy, answer)
|
||||
|
||||
|
||||
def computeJGA(greedy, answer, example_ids):
|
||||
# Inputs contain diff states, so we need to compute the full state first
|
||||
dataset = Bitod()
|
||||
hit = 0
|
||||
cur_dial_id = None
|
||||
full_answer = []
|
||||
full_greedy = []
|
||||
assert len(example_ids) == len(greedy) == len(answer)
|
||||
for id_, g, a in zip(example_ids, greedy, answer):
|
||||
dial_id = id_.split('/')[1]
|
||||
|
@ -257,13 +258,18 @@ def computeJGA(greedy, answer, example_ids):
|
|||
dataset.update_state(a, answer_state)
|
||||
dataset.update_state(g, greedy_state)
|
||||
|
||||
answer_state_sets = convert_lists_to_set(answer_state)
|
||||
greedy_state_sets = convert_lists_to_set(greedy_state)
|
||||
full_answer.append(answer_state)
|
||||
full_greedy.append(greedy_state)
|
||||
|
||||
if answer_state_sets == greedy_state_sets:
|
||||
hit += 1
|
||||
return compute_dst_em(full_greedy, full_answer)
|
||||
|
||||
return hit / len(greedy) * 100
|
||||
|
||||
def computeDST_EM(greedy, answer):
|
||||
# Calculate exact match between diff states
|
||||
dataset = Bitod()
|
||||
answer = [dataset.span2state(a[0]) for a in answer]
|
||||
greedy = [dataset.span2state(g) for g in greedy]
|
||||
return compute_dst_em(greedy, answer)
|
||||
|
||||
|
||||
def convert_IOB2_to_IOB1(labels):
|
||||
|
@ -337,6 +343,14 @@ def compute_metrics(
|
|||
em = computeEM(predictions, answers)
|
||||
metric_keys += ['em']
|
||||
metric_values += [em]
|
||||
if 'da_em' in requested_metrics:
|
||||
da_em = computeDA(predictions, answers)
|
||||
metric_keys += ['da_em']
|
||||
metric_values += [da_em]
|
||||
if 'dst_em' in requested_metrics:
|
||||
dst_em = computeDST_EM(predictions, answers)
|
||||
metric_keys += ['dst_em']
|
||||
metric_values += [dst_em]
|
||||
if 'pem' in requested_metrics:
|
||||
pem = computePartialEM(predictions, answers)
|
||||
metric_keys.append('pem')
|
||||
|
|
Loading…
Reference in New Issue