diff --git a/genienlp/metrics.py b/genienlp/metrics.py index 4ce17a1b..68fb54fc 100644 --- a/genienlp/metrics.py +++ b/genienlp/metrics.py @@ -34,7 +34,7 @@ from typing import List, Union import sacrebleu from datasets import load_metric from dialogues import Bitod -from dialogues.bitod.src.evaluate import convert_lists_to_set +from dialogues.bitod.src.evaluate import compute_da, compute_dst_em, compute_ser from seqeval import metrics as seq_metrics from seqeval import scheme as seq_scheme @@ -225,23 +225,24 @@ def compute_e2e_dialogue_score(greedy, answer, tgt_lang, args, example_ids, cont def computeSER(greedy, inputs): - ser = 0.0 - for pred, input in zip(greedy, inputs): - act_values = QUOTED_MATCH_REGEX.findall(input) - missing = False - if len(act_values): - for val in act_values: - if val not in pred: - missing = True - if missing: - ser += 1.0 - return ser / len(greedy) * 100 + act_values = [] + for input in inputs: + act_values += QUOTED_MATCH_REGEX.findall(input) + + return compute_ser(greedy, act_values) + + +def computeDA(greedy, answer): + answer = [a[0] for a in answer] + return compute_da(greedy, answer) def computeJGA(greedy, answer, example_ids): + # Inputs contain diff states, so we need to compute the full state first dataset = Bitod() - hit = 0 cur_dial_id = None + full_answer = [] + full_greedy = [] assert len(example_ids) == len(greedy) == len(answer) for id_, g, a in zip(example_ids, greedy, answer): dial_id = id_.split('/')[1] @@ -257,13 +258,18 @@ def computeJGA(greedy, answer, example_ids): dataset.update_state(a, answer_state) dataset.update_state(g, greedy_state) - answer_state_sets = convert_lists_to_set(answer_state) - greedy_state_sets = convert_lists_to_set(greedy_state) + full_answer.append(answer_state) + full_greedy.append(greedy_state) - if answer_state_sets == greedy_state_sets: - hit += 1 + return compute_dst_em(full_greedy, full_answer) - return hit / len(greedy) * 100 + +def computeDST_EM(greedy, answer): + # Calculate exact match between diff states + dataset = Bitod() + answer = [dataset.span2state(a[0]) for a in answer] + greedy = [dataset.span2state(g) for g in greedy] + return compute_dst_em(greedy, answer) def convert_IOB2_to_IOB1(labels): @@ -337,6 +343,14 @@ def compute_metrics( em = computeEM(predictions, answers) metric_keys += ['em'] metric_values += [em] + if 'da_em' in requested_metrics: + da_em = computeDA(predictions, answers) + metric_keys += ['da_em'] + metric_values += [da_em] + if 'dst_em' in requested_metrics: + dst_em = computeDST_EM(predictions, answers) + metric_keys += ['dst_em'] + metric_values += [dst_em] if 'pem' in requested_metrics: pem = computePartialEM(predictions, answers) metric_keys.append('pem')