import sys
from collections import Counter
from functools import reduce

import numpy as np
import torch
from torch.autograd import Variable

from modules.utils import CUDA_wrapper, LongTensor, FloatTensor


def eprint(*args, **kwargs):
    """print() that writes to stderr."""
    print(*args, file=sys.stderr, **kwargs)


class Reslicer:
    def __init__(self, max_length):
        """
        Functor mapping an offset x to max_length - x. Used to guard
        against empty slices when an index-selected region would have
        non-positive size.
        """
        self.max_l = max_length

    def __call__(self, x):
        return self.max_l - x


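# Minimal illustration (a sketch, not used by the module itself): for a
# sequence of length 5 and 3-grams, Reslicer(5)(3 - 1) == 3, so slices
# such as A[:, :3, :3] below stay non-empty; once the n-gram order
# exceeds the length, the result is non-positive and the caller bails out.
#
#     reslicer = Reslicer(5)
#     assert reslicer(3 - 1) == 3

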
def ngrams_product(A, n):
    """
    A - probability matrix [batch x length_candidate_translation x reference_len];
    the third dimension follows the reference's words in their order of
    appearance in the reference.
    n - the n-gram order.
    Output: [batch x (max_l - n + 1) x (max_l - n + 1)], where
    max_l = min(length_candidate_translation, reference_len); entry (i, j)
    is the probability that the n-gram starting at candidate position i
    matches the n-gram starting at reference position j. Returns None
    when no n-gram of order n fits.
    """
    max_l = min(A.size()[1:])
    reslicer = Reslicer(max_l)
    if reslicer(n - 1) <= 0:
        return None
    # Multiply n diagonally shifted copies of A, so that entry (i, j)
    # accumulates prod_k A[:, i + k, j + k] for k = 0 .. n - 1.
    cur = A[:, :reslicer(n - 1), :reslicer(n - 1)].clone()
    for i in range(1, n):
        cur = cur * A[:, i:reslicer(n - 1 - i), i:reslicer(n - 1 - i)]
    return cur


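# Toy check of ngrams_product (a sketch; shapes and values are made up):
# for bigrams, output entry (i, j) equals A[:, i, j] * A[:, i + 1, j + 1].
#
#     A = torch.full((1, 3, 3), 0.5)   # every token matches with p = 0.5
#     P2 = ngrams_product(A, 2)        # shape [1, 2, 2]
#     assert torch.allclose(P2, torch.full((1, 2, 2), 0.25))

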
def get_selected_matrices(probs, references, dim=1):
    """
    Batched index select.
    probs - a batch of matrices
    references - a batch of index lists
    dim - the dimension of a batch element to select along
    """
    # NOTE: a for loop over the batch. This seems to be the only way to
    # do it for now: it could be optimized with a batched version of
    # index_select, but there is no batched_index_select in pytorch yet.
    return torch.cat([
        torch.index_select(a, dim, Variable(LongTensor(i))).unsqueeze(0)
        for a, i in zip(probs, references)
    ])


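# Sketch of what get_selected_matrices yields here (names are made up):
# with probs holding a distribution over the vocabulary at each candidate
# position, the result at (i, j) is the probability that candidate
# position i produces reference word j.
#
#     probs = torch.softmax(torch.randn(2, 4, 10), dim=-1)  # vocab of 10
#     refs = [[1, 5, 2], [0, 0, 3]]
#     sel = get_selected_matrices(probs, refs)              # [2, 4, 3]

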
def ngram_ref_counts(reference, lengths, n):
    """
    For each start position, count how many times the n-gram that begins
    there occurs in the whole reference.
    reference - matrix of sequences of vocabulary ids [batch x ref_len]
    NOTE: references should be padded with some special ids,
    and at least one value in lengths must equal reference.shape[1]
    output: n-gram counts for each start position, padded with ones,
    or None when no n-gram of order n fits
    """
    res = []
    max_len = max(lengths)
    if max_len - n + 1 <= 0:
        return None
    for r, l in zip(reference, lengths):
        picked = set()  # we only count the first appearance of an n-gram
        # (it carries the n-gram's total number of occurrences)
        current_length = l - n + 1
        cnt = Counter([tuple([r[i + j] for j in range(n)])
                       for i in range(current_length)])
        occurrence = []
        for i in range(current_length):
            n_gram = tuple([r[i + j] for j in range(n)])
            val = 0
            if n_gram not in picked:
                val = cnt[n_gram]
                picked.add(n_gram)
            occurrence.append(val)
        padding = [1 for _ in range(max_len - l if current_length > 0
                                    else max_len - n + 1)]
        res.append(occurrence + padding)
    return Variable(FloatTensor(res), requires_grad=False)


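# Worked example (a sketch): for the reference "a b a b" as ids
# [1, 2, 1, 2] and n = 2, the bigram (1, 2) occurs twice and (2, 1) once,
# so the counts per start position are [2, 1, 0] - the second (1, 2)
# contributes 0 because only the first appearance carries the count.
#
#     counts = ngram_ref_counts([[1, 2, 1, 2]], [4], 2)
#     # -> [[2., 1., 0.]]

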
def calculate_overlap(p, r, n, lengths):
    """
    p - probability tensor [b x len_x x reference_length]
    r - references, a batch of sequences of word ids [b x len_y]
    n - the n-gram order
    lengths - lengths of the references in the batch
    """
    A = ngrams_product(get_selected_matrices(p, r), n)
    r_cnt = ngram_ref_counts(r, lengths, n)
    if A is None or r_cnt is None:
        return CUDA_wrapper(torch.zeros(p.shape[0]))
    r_cnt = r_cnt[:, None]
    # Soft analogue of BLEU's clipping: A_div estimates how often the
    # n-gram at reference position j appears in the candidate, so each
    # match is scaled down once it exceeds the reference count r_cnt.
    A_div = -A + torch.sum(A, 1, keepdim=True) + 1
    second_arg = r_cnt / A_div
    term = torch.min(A, A * second_arg)
    return torch.sum(torch.sum(term, 2), 1)


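# Usage sketch for calculate_overlap (values are made up): the result is
# one number per batch element, the expected clipped n-gram matches.
#
#     p = torch.softmax(torch.randn(1, 5, 8), dim=-1)  # 5 steps, vocab 8
#     overlap = calculate_overlap(p, [[1, 2, 3, 4]], 2, [4])  # shape [1]

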
def bleu(p, r, translation_lengths, reference_lengths, max_order=4,
         smooth=False):
    """
    p - matrix with probabilities
    r - reference batch
    translation_lengths - lengths of the translations, a torch tensor
    reference_lengths - lengths of the references
    max_order - max order of n-gram
    smooth - whether to smooth the precision calculation
    Returns the negated corpus BLEU (suitable for minimization as a loss)
    and the per-order precisions.
    """
    overlaps_list = []
    translation_length = sum(translation_lengths)
    reference_length = sum(reference_lengths)
    for n in range(1, max_order + 1):
        overlaps_list.append(calculate_overlap(p, r, n, reference_lengths))
    overlaps = torch.stack(overlaps_list)
    matches_by_order = torch.sum(overlaps, 1)
    # Keep this on the same device as matches_by_order.
    possible_matches_by_order = CUDA_wrapper(torch.zeros(max_order))
    for n in range(1, max_order + 1):
        # a translation of length L contains max(L - n + 1, 0) n-grams
        cur_pm = translation_lengths.float() - n + 1
        mask = cur_pm > 0
        cur_pm *= mask.float()
        possible_matches_by_order[n - 1] = torch.sum(cur_pm)
    precisions = Variable(FloatTensor([0] * max_order))
    for i in range(max_order):
        if smooth:
            precisions[i] = (matches_by_order[i] + 1) / \
                            (possible_matches_by_order[i] + 1)
        else:
            if possible_matches_by_order[i] > 0:
                precisions[i] = matches_by_order[i] / \
                                possible_matches_by_order[i]
            else:
                precisions[i] = Variable(FloatTensor([0]))
    if torch.min(precisions[:max_order]).item() > 0:
        p_log_sum = sum([(1. / max_order) * torch.log(prec)
                         for prec in precisions])
        geo_mean = torch.exp(p_log_sum)
    else:
        geo_mean = torch.pow(
            reduce(lambda x, y: x * y, precisions), 1. / max_order)
        eprint('WARNING: some precision(s) is zero')
    # Brevity penalty, clamped from below so tiny ratios cannot blow up
    # the 1 / ratio term.
    ratio = float(translation_length) / reference_length
    if ratio > 1.0:
        bp = 1.0
    else:
        THRESHOLD_RATIO = 1E-1
        MIN_BP = 1E-2
        if ratio > THRESHOLD_RATIO:
            bp = np.exp(1 - 1. / ratio)
        else:
            bp = MIN_BP
    # Negated so that minimizing this value maximizes BLEU.
    bleu = -geo_mean * bp
    return bleu, precisions
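

if __name__ == '__main__':
    # Smoke test (a sketch, not part of the original module). Assumes
    # CUDA_wrapper moves a tensor to the GPU when one is available and
    # is the identity otherwise; all sizes and ids below are made up.
    torch.manual_seed(0)
    batch, steps, vocab = 2, 6, 12
    p = CUDA_wrapper(torch.softmax(torch.randn(batch, steps, vocab), dim=-1))
    r = [[1, 2, 3, 4, 5, 6], [7, 8, 9, 0, 0, 0]]  # 0 is the padding id
    translation_lengths = CUDA_wrapper(torch.LongTensor([6, 4]))
    reference_lengths = [6, 3]
    loss, precisions = bleu(p, r, translation_lengths, reference_lengths,
                            smooth=True)
    eprint('negative BLEU loss:', loss)
    eprint('precisions per order:', precisions)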