metrics: first attempt to unify dialogue metrics between genienlp and dialogues

This commit is contained in:
mehrad 2022-05-13 15:40:53 -07:00
parent 95d7d5630c
commit 77ec8526cc
No known key found for this signature in database
GPG Key ID: AAF81F778210AE42
1 changed files with 32 additions and 18 deletions

View File

@ -34,7 +34,7 @@ from typing import List, Union
import sacrebleu
from datasets import load_metric
from dialogues import Bitod
from dialogues.bitod.src.evaluate import convert_lists_to_set
from dialogues.bitod.src.evaluate import compute_da, compute_dst_em, compute_ser
from seqeval import metrics as seq_metrics
from seqeval import scheme as seq_scheme
@ -225,23 +225,24 @@ def compute_e2e_dialogue_score(greedy, answer, tgt_lang, args, example_ids, cont
def computeSER(greedy, inputs):
ser = 0.0
for pred, input in zip(greedy, inputs):
act_values = QUOTED_MATCH_REGEX.findall(input)
missing = False
if len(act_values):
for val in act_values:
if val not in pred:
missing = True
if missing:
ser += 1.0
return ser / len(greedy) * 100
act_values = []
for input in inputs:
act_values += QUOTED_MATCH_REGEX.findall(input)
return compute_ser(greedy, act_values)
def computeDA(greedy, answer):
answer = [a[0] for a in answer]
return compute_da(greedy, answer)
def computeJGA(greedy, answer, example_ids):
# Inputs contain diff states, so we need to compute the full state first
dataset = Bitod()
hit = 0
cur_dial_id = None
full_answer = []
full_greedy = []
assert len(example_ids) == len(greedy) == len(answer)
for id_, g, a in zip(example_ids, greedy, answer):
dial_id = id_.split('/')[1]
@ -257,13 +258,18 @@ def computeJGA(greedy, answer, example_ids):
dataset.update_state(a, answer_state)
dataset.update_state(g, greedy_state)
answer_state_sets = convert_lists_to_set(answer_state)
greedy_state_sets = convert_lists_to_set(greedy_state)
full_answer.append(answer_state)
full_greedy.append(greedy_state)
if answer_state_sets == greedy_state_sets:
hit += 1
return compute_dst_em(full_greedy, full_answer)
return hit / len(greedy) * 100
def computeDST_EM(greedy, answer):
# Calculate exact match between diff states
dataset = Bitod()
answer = [dataset.span2state(a[0]) for a in answer]
greedy = [dataset.span2state(g) for g in greedy]
return compute_dst_em(greedy, answer)
def convert_IOB2_to_IOB1(labels):
@ -337,6 +343,14 @@ def compute_metrics(
em = computeEM(predictions, answers)
metric_keys += ['em']
metric_values += [em]
if 'da_em' in requested_metrics:
da_em = computeDA(predictions, answers)
metric_keys += ['da_em']
metric_values += [da_em]
if 'dst_em' in requested_metrics:
dst_em = computeDST_EM(predictions, answers)
metric_keys += ['dst_em']
metric_values += [dst_em]
if 'pem' in requested_metrics:
pem = computePartialEM(predictions, answers)
metric_keys.append('pem')