genienlp/modules/matrixBLEU.py

115 lines
4.9 KiB
Python
Executable File

import torch
from torch.nn import functional
from torch.autograd import Variable
import numpy as np
import os
from functools import reduce
from copy import deepcopy as copy
import time
from modules.utils import CUDA_wrapper
from modules.utils import SoftmaxWithTemperature
from modules.utils import fill_eye_diag
class mBLEU:
    """Straightforward matrix ("soft") BLEU computation on batched tensors.

    N-gram matches are obtained from the batched products ``T @ T^T`` and
    ``T @ R^T`` instead of from discrete token counts, so the score is built
    from tensor operations on ``T``.
    """

    def __init__(self, max_order=4, softmax_temperature=0.001, T_argmax=True,
                 std_temp=False):
        """Configure the scorer.

        Args:
            max_order: highest n-gram order taken into account.
            softmax_temperature: temperature handed to
                ``SoftmaxWithTemperature`` (project helper).
            T_argmax: when true, ``T`` is passed through ``self.sm`` before
                scoring.
            std_temp: when true (and ``T_argmax`` is true), ``T.std()`` is
                passed to ``self.sm`` as the ``temperature`` argument.
        """
        self.max_order = max_order
        self.T_argmax = T_argmax
        self.sm = SoftmaxWithTemperature(softmax_temperature)
        # Not used inside this class; kept because it is a public attribute.
        self.softmax_regular = torch.nn.Softmax()
        self.std_temp = std_temp

    @staticmethod
    def _constant(value, like):
        """Return a float32 tensor of shape [1] holding `value` on `like`'s device."""
        return torch.full((1,), float(value), dtype=torch.float32,
                          device=like.device)

    @staticmethod
    def _clipped_overlap(t, tr):
        """Per-sentence smoothed overlap: sum(min(t, tr) / max(t, 1)) over dim 1."""
        SMOOTH_CONST = 1E-10
        one = mBLEU._constant(1, t)
        # torch.max (not clamp) is kept on purpose: the two differ in how the
        # gradient is routed when t + SMOOTH_CONST equals 1 exactly.
        return torch.sum(
            (torch.min(t, tr) + SMOOTH_CONST)
            / torch.max(t + SMOOTH_CONST, one),
            1)

    @staticmethod
    def _add_possible_matches(possible_matches_by_order,
                              translation_corpus_lens, order):
        """Accumulate max(len - order, 0) over all translations, in place."""
        for transl_len in translation_corpus_lens:
            possible_matches = transl_len - order
            if possible_matches > 0:
                possible_matches_by_order[order] += possible_matches

    @staticmethod
    def _brevity_penalty(translation_length, reference_length):
        """Return the brevity penalty as a Python float.

        The penalty is 1 when the translation is longer than the reference,
        ``exp(1 - 1 / ratio)`` otherwise, and is floored at ``MIN_BP`` once
        the length ratio drops to ``THRESHOLD_RATIO`` or below.
        """
        if reference_length == 0:
            # Nothing to be "too short" against; the original code raised
            # ZeroDivisionError here.
            return 1.
        ratio = float(translation_length) / reference_length
        if ratio > 1.0:
            return 1.
        THRESHOLD_RATIO = 1E-1
        MIN_BP = 1E-2
        if ratio > THRESHOLD_RATIO:
            return float(np.exp(1 - 1. / ratio))
        return MIN_BP

    def __call__(self, R, T, reference_corpus_lens, translation_corpus_lens):
        """Compute the negated matrix BLEU of translations `T` against `R`.

        Args:
            R: reference tensor. It is used as a 3-D tensor
                (``R.transpose(1, 2)`` and ``bmm`` with ``T``), so it is
                presumably ``[b x r x v]`` (e.g. one-hot rows) -- TODO
                confirm against callers.
            T: translation tensor ``[b x t x v]``.
            reference_corpus_lens: list of reference lengths, len=b.
            translation_corpus_lens: list of translation lengths, len=b.

        Returns:
            ``(bleu, precisions)`` where ``bleu`` is
            ``-geometric_mean(precisions) * brevity_penalty`` (note the
            minus sign) and ``precisions`` is a list of ``max_order``
            tensors, one per n-gram order.
        """
        max_order = self.max_order
        shapeR = R.shape
        shapeT = T.shape
        translation_length = sum(translation_corpus_lens)
        reference_length = sum(reference_corpus_lens)

        if self.T_argmax:
            cur_temperature = None
            if self.std_temp:
                cur_temperature = T.std()
                # Sampled debug output (about 1% of calls), kept from the
                # original implementation.
                if np.random.rand(1)[0] > 0.99:
                    print(cur_temperature)
            T = self.sm(T.contiguous().view(-1, shapeT[2]),
                        temperature=cur_temperature).view(shapeT)

        # Pairwise position similarities: translation/reference and
        # translation/translation.
        TR = T.bmm(R.transpose(1, 2))
        TT = T.bmm(T.transpose(1, 2))

        # Constants live on T's device so they can be combined with TT / TR.
        matches_by_order = [self._constant(0, T) for _ in range(max_order)]
        possible_matches_by_order = [self._constant(0, T)
                                     for _ in range(max_order)]

        # Order 0 corresponds to unigrams.
        cur_t = TT
        cur_tr = TR
        overlap = self._clipped_overlap(torch.sum(cur_t, 1),
                                        torch.sum(cur_tr, 2))
        matches_by_order[0] = torch.sum(overlap)
        self._add_possible_matches(possible_matches_by_order,
                                   translation_corpus_lens, 0)

        # Higher orders: multiplying by the matrix shifted along the diagonal
        # keeps only positions whose preceding positions matched as well.
        for order in range(1, min(max_order, shapeT[1], shapeR[1])):
            cur_t = TT[:, order:, order:] * cur_t[:, :-1, :-1]
            cur_tr = TR[:, order:, order:] * cur_tr[:, :-1, :-1]
            overlap = self._clipped_overlap(torch.sum(cur_t, 1),
                                            torch.sum(cur_tr, 2))
            matches_by_order[order] = torch.sum(overlap)
            self._add_possible_matches(possible_matches_by_order,
                                       translation_corpus_lens, order)

        precisions = []
        for order in range(max_order):
            # .item() instead of .data[0]: indexing a 0-dim tensor raises
            # IndexError on current PyTorch.
            if possible_matches_by_order[order].item() > 0:
                if order > 0:
                    # Add-one smoothing for every order above unigrams.
                    precision = ((matches_by_order[order].float() + 1)
                                 / (possible_matches_by_order[order] + 1))
                else:
                    precision = (matches_by_order[order].float()
                                 / possible_matches_by_order[order])
            else:
                precision = self._constant(0, T)
            precisions.append(precision)

        if torch.min(torch.stack(precisions)).item() > 1E-3:
            p_log_sum = sum([(1. / max_order) * torch.log(p)
                             for p in precisions])
            geo_mean = torch.exp(p_log_sum)
        else:
            # Avoid log() of (near-)zero precisions: take the root of the
            # plain product instead.
            geo_mean = torch.pow(
                reduce(lambda x, y: x * y, precisions), 1. / max_order)

        bp = self._brevity_penalty(translation_length, reference_length)
        bleu = -geo_mean * bp
        return bleu, precisions