metrics: fix SER and JGA

This commit is contained in:
mehrad 2022-05-16 15:08:12 -07:00
parent 77ec8526cc
commit db482ddb4b
No known key found for this signature in database
GPG Key ID: AAF81F778210AE42
1 changed files with 4 additions and 4 deletions

View File

@ -26,7 +26,7 @@
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import copy
import logging
from collections import Counter, OrderedDict, defaultdict
from typing import List, Union
@ -227,7 +227,7 @@ def compute_e2e_dialogue_score(greedy, answer, tgt_lang, args, example_ids, cont
def computeSER(greedy, inputs):
act_values = []
for input in inputs:
act_values += QUOTED_MATCH_REGEX.findall(input)
act_values.append(QUOTED_MATCH_REGEX.findall(input))
return compute_ser(greedy, act_values)
@ -258,8 +258,8 @@ def computeJGA(greedy, answer, example_ids):
dataset.update_state(a, answer_state)
dataset.update_state(g, greedy_state)
full_answer.append(answer_state)
full_greedy.append(greedy_state)
full_answer.append(copy.deepcopy(answer_state))
full_greedy.append(copy.deepcopy(greedy_state))
return compute_dst_em(full_greedy, full_answer)