66 lines
2.3 KiB
Python
Executable File
66 lines
2.3 KiB
Python
Executable File
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torch.autograd import Variable
|
|
from copy import deepcopy
|
|
from collections import Counter
|
|
from copy import deepcopy as copy
|
|
from modules.matrixBLEU import mBLEU
|
|
from modules.utils import CUDA_wrapper
|
|
import itertools
|
|
from functools import reduce
|
|
from modules.utils import LongTensor, FloatTensor
|
|
import time
|
|
|
|
import matplotlib
|
|
matplotlib.use('PDF')
|
|
import matplotlib.pyplot as plt
|
|
|
|
def one_hots(zeros, ix):
    """Write a 1 into one column of every row of ``zeros``, in place.

    For each row index ``i`` of ``zeros`` the entry ``zeros[i, ix[i]]`` is
    set to 1.  The row count is taken with ``len`` so that any 2-D container
    supporting ``zeros[i, j] = value`` works (torch tensors as well as numpy
    arrays), not only objects exposing torch's callable ``.size()``.

    Args:
        zeros: 2-D array-like that is modified in place.  Presumably
            pre-filled with zeros by the caller; this function does not
            clear it.
        ix: indexable giving, for row ``i``, the column to set.  Must have at
            least as many entries as ``zeros`` has rows.

    Returns:
        The same ``zeros`` object that was passed in.
    """
    for row in range(len(zeros)):
        zeros[row, ix[row]] = 1
    return zeros
|
|
|
|
def overlap(t, r_hot, r, f, temp, n):
    """Calculate the expected clipped n-gram overlap of a soft candidate.

    Mirrors the overlap count of the original BLEU script (see Google's nmt
    ``bleu.py``), but computed in expectation over the per-position
    distributions ``f(t / temp)`` instead of over hard tokens.

    Args:
        t: candidate scores, indexed as ``[position, token]``.
        r_hot: reference encoding, indexed as ``[position, token]``
            (presumably the one-hot form of ``r`` -- confirm with caller).
        r: reference token ids; iterated element-wise and read through
            ``.data[0]`` and ``.data.shape[0]``.
        f: callable applied to ``t / temp`` to obtain per-position token
            probabilities (presumably a softmax over tokens -- confirm).
        temp: temperature that ``t`` is divided by before ``f``.
        n: n-gram order; expected to be >= 1.

    Returns:
        A one-element tensor holding the summed expected overlap.  It stays
        zero when the candidate or the reference is shorter than ``n``.
    """
    def _product(factors):
        # Element-wise product of a non-empty list of tensors/scalars.
        return reduce(lambda a, b: a * b, factors)

    t_soft = f(t / temp)
    length = t.size()[0]
    ref_length = r.data.shape[0]
    # NOTE(review): ``i.data[0]`` relies on iteration over ``r`` yielding
    # one-element tensors; kept as in the original -- confirm against the
    # torch version in use.
    from_ref = [i.data[0] for i in r]
    res = CUDA_wrapper(Variable(FloatTensor([0])))

    n_positions = length - n + 1
    # Every n-gram window of the reference, duplicates included.
    # NOTE(review): a repeated reference n-gram is visited once per
    # occurrence, exactly as before -- confirm this weighting is intended.
    ref_ngrams = [from_ref[i:i + n] for i in range(len(from_ref) - n + 1)]
    if n_positions <= 0 or not ref_ngrams:
        # Nothing to match: no candidate window or no reference n-gram.
        return res

    # Everything below depends only on the reference n-gram, not on the
    # candidate position, so it is computed once per n-gram instead of once
    # per (position, n-gram) pair.
    ngram_stats = []
    for m in ref_ngrams:
        # j is the index of the current word inside the n-gram.
        ref_hits = _product(
            [r_hot[j:ref_length - n + 1 + j, m[j]] for j in range(n)])
        ref_count = ref_hits.sum(0)
        # Probability of the n-gram at every candidate window.  The slice
        # bound uses the candidate length: the original used the reference
        # length here, which only works when both lengths are equal.
        cand_probs = _product(
            [t_soft[j:n_positions + j, m[j]] for j in range(n)])
        ngram_stats.append((m, ref_count, cand_probs, cand_probs.sum(0)))

    for i in range(n_positions):
        window = [t_soft[i + j] for j in range(n)]
        for m, ref_count, cand_probs, cand_total in ngram_stats:
            # Probability that the n-gram m sits at window i.
            pr = _product([window[j][m[j]] for j in range(n)])
            # 1 + expected number of occurrences of m at the other windows.
            denominator = 1 + cand_total - cand_probs[i]
            # Clip the contribution by the reference count, as BLEU does.
            res += torch.min(pr, pr * ref_count / denominator)
    return res
|
|
|
|
def precision(t, r_hot, r, f, temp, n):
    """Return the expected n-gram precision of the soft candidate.

    The expected overlap returned by ``overlap`` is divided by the number of
    n-gram windows in the candidate, ``t.data.shape[0] - n + 1``.  All
    arguments are forwarded to ``overlap`` unchanged.
    """
    matched = overlap(t, r_hot, r, f, temp, n)
    window_count = t.data.shape[0] - n + 1
    return matched / window_count
|
|
|
|
def bleu(t, r_hot, r, f, temp, n):
    """Combine the expected precisions of orders 1..n into one score.

    Each order's ``precision`` is log-transformed and weighted by ``1 / n``;
    the exponential of the weighted sum (i.e. the geometric mean of the
    precisions) is returned.  No brevity penalty is applied here.
    """
    log_mean = 0
    for order in range(1, n + 1):
        p_order = precision(t, r_hot, r, f, temp, order)
        # Uniform weight 1/n per order, accumulated out of place.
        log_mean = log_mean + (1. / n) * torch.log(p_order)
    return torch.exp(log_mean)
|