115 lines
4.9 KiB
Python
Executable File
115 lines
4.9 KiB
Python
Executable File
import torch
|
|
from torch.nn import functional
|
|
from torch.autograd import Variable
|
|
import numpy as np
|
|
import os
|
|
from functools import reduce
|
|
from copy import deepcopy as copy
|
|
import time
|
|
from modules.utils import CUDA_wrapper
|
|
from modules.utils import SoftmaxWithTemperature
|
|
from modules.utils import fill_eye_diag
|
|
|
|
class mBLEU:
|
|
def __init__(self, max_order=4, softmax_temperature=0.001, T_argmax=True,\
|
|
std_temp=False):
|
|
"""class implementing straightforwad matrix BLEU computation"""
|
|
self.max_order = max_order
|
|
self.T_argmax = T_argmax
|
|
self.sm = SoftmaxWithTemperature(softmax_temperature)
|
|
self.softmax_regular = torch.nn.Softmax()
|
|
self.std_temp = std_temp
|
|
|
|
def __call__(self, R, T, reference_corpus_lens, translation_corpus_lens):
|
|
"""
|
|
T[b x t x v]
|
|
R[b x r]
|
|
reference_corpus_lens - list, len=b
|
|
translation_corpus_lens - list, len=b
|
|
"""
|
|
max_order = self.max_order
|
|
shapeR = R.data.shape
|
|
shapeT = T.data.shape
|
|
translation_length = sum(translation_corpus_lens)
|
|
reference_length = sum(reference_corpus_lens)
|
|
if self.T_argmax:
|
|
cur_temperature = None
|
|
if self.std_temp:
|
|
cur_temperature = T.std()
|
|
if (np.random.rand(1)[0] > 0.99):
|
|
print(cur_temperature)
|
|
T = self.sm(T.contiguous().view(-1, shapeT[2]),\
|
|
temperature=cur_temperature).view(shapeT)
|
|
TR = T.bmm(R.transpose(1, 2))
|
|
TT = T.bmm(T.transpose(1, 2))
|
|
# TT = fill_eye_diag(TT)
|
|
|
|
reference_len = sum(reference_corpus_lens)
|
|
tanslation_len = sum(translation_corpus_lens)
|
|
matches_by_order = [CUDA_wrapper(Variable(torch.FloatTensor([0])))\
|
|
for i in range(max_order)]
|
|
cur_t = TT
|
|
cur_tr = TR
|
|
all_t = [torch.sum(cur_t, 1)]
|
|
all_tr = [torch.sum(cur_tr, 2)]
|
|
def overlapper(t, tr):
|
|
SMOOTH_CONST = 1E-10
|
|
return torch.sum((torch.min(t, tr) + SMOOTH_CONST) / torch.max(\
|
|
(t + SMOOTH_CONST),CUDA_wrapper(Variable(\
|
|
torch.FloatTensor([1])))), 1)
|
|
overlap = overlapper(all_t[-1], all_tr[-1])
|
|
matches_by_order[0] = torch.sum(overlap)
|
|
possible_matches_by_order = [
|
|
CUDA_wrapper(Variable(torch.FloatTensor([0])))\
|
|
for i in range(max_order)\
|
|
]
|
|
def update_possible_matches(possible_matches_by_order,\
|
|
translation_corpus_lens, order):
|
|
for transl_len in translation_corpus_lens:
|
|
possible_matches = transl_len - order
|
|
if possible_matches > 0:
|
|
possible_matches_by_order[order] += possible_matches
|
|
update_possible_matches(possible_matches_by_order,\
|
|
translation_corpus_lens, 0)
|
|
for order in range(1, min(max_order, shapeT[1], shapeR[1])):
|
|
cur_t = TT[:, order:, order:] * cur_t[:, :-1, :-1]
|
|
all_t.append(torch.sum(cur_t, 1))
|
|
cur_tr = TR[:, order:, order:] * cur_tr[:, :-1, :-1]
|
|
all_tr.append(torch.sum(cur_tr, 2))
|
|
overlap = overlapper(all_t[-1], all_tr[-1])
|
|
matches_by_order[order] = torch.sum(overlap)
|
|
update_possible_matches(possible_matches_by_order,\
|
|
translation_corpus_lens, order)
|
|
|
|
precisions = [CUDA_wrapper(Variable(torch.FloatTensor([0])))\
|
|
for i in range(max_order)]
|
|
for i in range(0, max_order):
|
|
if possible_matches_by_order[i].data[0] > 0:
|
|
if i > 0:
|
|
precisions[i] = ((matches_by_order[i].float() + 1)\
|
|
/( possible_matches_by_order[i] + 1))
|
|
else:
|
|
precisions[i] = (matches_by_order[i].float()\
|
|
/possible_matches_by_order[i])
|
|
else:
|
|
precisions[i] = CUDA_wrapper(Variable(torch.FloatTensor([0])))
|
|
if torch.min(torch.stack(precisions)).data[0] > 1E-3:
|
|
p_log_sum = sum([(1. / max_order) * torch.log(p)\
|
|
for p in precisions])
|
|
geo_mean = torch.exp(p_log_sum)
|
|
else:
|
|
geo_mean = torch.pow(\
|
|
reduce(lambda x, y: x*y, precisions), 1./max_order)
|
|
ratio = float(translation_length) / reference_length
|
|
if ratio > 1.0:
|
|
bp = 1.
|
|
else:
|
|
THRESHOLD_RATIO = 1E-1
|
|
MIN_BP = 1E-2
|
|
if ratio > THRESHOLD_RATIO:
|
|
bp = np.exp(1 - 1. / ratio)
|
|
else:
|
|
bp = MIN_BP
|
|
bleu = -geo_mean * bp
|
|
return bleu, precisions
|