Add differentiable BLEU loss

- Use a differentiable BLEU loss instead of the cross-entropy loss
- It helps decrease the train-test evaluation gap
mehrad 2018-11-27 15:22:38 -08:00
@ -83,8 +83,9 @@ def parse():
parser.add_argument('--reverse', action='store_true', help='if token_testing and true, sorts all iterators in reverse')
parser.add_argument('--reverse_task', action='store_true', dest='reverse_task_bool', help='whether to translate english to code or the other way around')
parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether use exisiting cached splits or generate new ones')
parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use exisiting cached splits or generate new ones')
parser.add_argument('--lr_rate', default=0.001, type=float, help='initial_learning_rate')
parser.add_argument('--use_bleu_loss', action='store_true', help='whether to use differentiable BLEU loss or not')

@ -7,9 +7,11 @@ from torch import nn
from torch.nn import functional as F
from util import get_trainable_params
from modules import expectedBLEU, expectedMultiBleu, matrixBLEU
from cove import MTLSTM
from allennlp.modules.elmo import Elmo
options_file = ""
weight_file = ""
@ -74,7 +76,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
def forward(self, batch):
def forward(self, batch, iteration):
context, context_lengths, context_limited = batch.context, batch.context_lengths, batch.context_limited
question, question_lengths, question_limited = batch.question, batch.question_lengths, batch.question_limited
answer, answer_lengths, answer_limited = batch.answer, batch.answer_lengths, batch.answer_limited
@ -134,6 +136,17 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
if self.args.use_bleu_loss and iteration >= 2.0/3 * max(self.args.train_iterations):
# if self.args.use_bleu_loss and iteration >= 1.0 / 3 * max(self.args.train_iterations):
max_order = 4
answer = answer[0][1:]
target = targets[0]
batch_size = 1
translation_len = answer.shape
loss = expectedMultiBleu.bleu(answer, torch.LongTensor(target), torch.FloatTensor([translation_len] * batch_size), translation_len, max_order=max_order, smooth=True)
loss = F.nll_loss(probs.log(), targets)
return loss, None

@ -0,0 +1,61 @@
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from copy import deepcopy
from collections import Counter
from copy import deepcopy as copy
from modules.matrixBLEU import mBLEU
from modules.utils import CUDA_wrapper
import itertools
from functools import reduce
from modules.utils import LongTensor, FloatTensor
import time
def one_hots(zeros, ix):
    """Set one entry per row of ``zeros`` to 1, in place, and return it.

    zeros - 2-D tensor that is modified in place
    ix - per-row column index; ``ix[k]`` is the column set to 1 in row ``k``
    """
    n_rows = zeros.size(0)
    for row_idx in range(n_rows):
        col_idx = ix[row_idx]
        zeros[row_idx, col_idx] = 1
    return zeros
def overlap(t, r_hot, r, f, temp, n):
""" calculate overlap as in original BLEU script but expected.
see google's nmt BLEU script for details

For every n-gram start position i of the candidate and every n-gram m of the
reference, adds min(pr, pr * y_prod / denominator) to an accumulator and
returns the accumulated value.

t - candidate scores, indexed as [position, vocabulary id]
r_hot - reference indicator matrix, indexed as [position, vocabulary id]
        (presumably one-hot rows - TODO confirm against the caller)
r - reference sequence
f - callable applied to t / temp to obtain t_soft
temp - temperature that t is divided by before f is applied
n - n-gram order
"""
t_soft = f(t / temp)
length = t.size()[0]
# v_size is not read again in this function.
v_size = t.size()[1]
# NOTE(review): the expression below looks truncated by extraction; as written
# every element of from_ref is the list [0], which is unhashable, so the
# mapper_ref dict comprehension two lines down would raise TypeError.
# Presumably this originally extracted the ids of r - TODO restore from upstream.
from_ref = list([[0] for i in r])
# from_ref_t and mapper_ref are not read again in this function.
from_ref_t = LongTensor(from_ref)
mapper_ref = {j:i for i, j in enumerate(from_ref)}
# Accumulator for the result (one-element float tensor).
res = CUDA_wrapper(Variable(FloatTensor([0])))
# M: every window of n consecutive reference entries.
M = [[from_ref[i + j] for j in range(n)] for i in range(len(from_ref) - n + 1)]
mul = lambda x, y: x * y
# start_all, start_select_t_soft and ngram_calc_cum are timing values that are
# recorded but never read in this function.
start_all = time.time()
for i in range(length - n + 1):
start_select_t_soft = time.time()
# pp: the n consecutive rows of t_soft starting at position i.
pp = [t_soft[i + j] for j in range(n)]
ngram_calc_cum = 0
for m in M:
# NOTE(review): `[0] + x` adds a list to an int and would raise TypeError;
# the left operand looks truncated by extraction (presumably a sequence
# length used as the slice end) - TODO restore from upstream.
reslicer = lambda x:[0] + x
ngram_calc_start = time.time()
# y_prod: product over the n shifted column slices of r_hot, summed over positions.
y_prod = reduce(mul,
[r_hot[j:reslicer(-n + 1 + j), m[j]] for j in range(n)]) # j is id of current word in sentence
y_prod = y_prod.sum(0)
# p_prod: the same shifted product taken over t_soft, kept per position.
p_prod = reduce(mul, \
[t_soft[j:reslicer(-n + 1 + j), m[j]] for j in range(n)])
# 1 + (sum of p_prod over all positions) - (p_prod at the current position i).
denominator = 1 + p_prod.sum(0) - p_prod[i]
ngram_calc_cum += time.time() - ngram_calc_start
# pr: product of the t_soft entries for n-gram m at candidate position i.
pr = reduce(mul, [pp[j][m[j]] for j in range(n)])
res += torch.min(pr, pr * y_prod / denominator)
return res
def precision(t, r_hot, r, f, temp, n):
    """Return the expected n-gram precision of candidate ``t``.

    The expected overlap is divided by the number of n-gram start positions
    in the candidate, ``t.size()[0] - n + 1`` (the same bound ``overlap``
    iterates over).

    t, r_hot, r, f, temp, n - forwarded unchanged to ``overlap``
    """
    # The original divisor was `[0] - n + 1` (list minus int -> TypeError);
    # the candidate length is the first dimension of t.
    n_positions = t.size()[0] - n + 1
    return overlap(t, r_hot, r, f, temp, n) / n_positions
def bleu(t, r_hot, r, f, temp, n):
    """Return the geometric mean of the expected precisions of orders 1..n.

    Computed as exp(sum_k (1/n) * log(precision_k)); no brevity penalty
    is applied here.

    t, r_hot, r, f, temp - forwarded unchanged to ``precision``
    n - maximum n-gram order
    """
    per_order = []
    for order in range(1, n + 1):
        per_order.append(precision(t, r_hot, r, f, temp, order))

    log_mean = 0
    for p in per_order:
        log_mean = log_mean + (1. / n) * torch.log(p)
    return torch.exp(log_mean)

@ -0,0 +1,164 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from collections import Counter
from copy import deepcopy as copy_deep
from copy import copy as copy
from modules.matrixBLEU import mBLEU
from modules.utils import CUDA_wrapper
from collections import Counter
from modules.utils import LongTensor, FloatTensor
from functools import reduce
from modules.utils import CUDA_wrapper
import sys
def eprint(*args, **kwargs):
    """Print to standard error; accepts the same arguments as ``print``
    except ``file``, which is always ``sys.stderr``."""
    err_stream = sys.stderr
    print(*args, file=err_stream, **kwargs)
class Reslicer:
    """Callable mapping an offset ``x`` to the slice bound ``max_l - x``."""

    def __init__(self, max_lenght):
        """
        This functor is used to prevent an empty reslice
        of index selecting when the offset appears to be zero.

        max_lenght - length that offsets are subtracted from
                     (parameter spelling kept for backward compatibility)
        """
        self.max_l = max_lenght

    def __call__(self, x):
        """Return ``max_l - x``."""
        return self.max_l - x
def ngrams_product(A, n):
    """
    Multiply probabilities along diagonals to score every n-gram alignment.

    A - probability matrix
        [batch x length_candidate_translation x reference_len];
        the third dimension is the reference's words in order of
        appearance in the reference
    n - n-gram order

    Output: [batch x (length_candidate_translation-n+1) x (reference_len-n+1)]
            with out[b, i, j] = prod over k < n of A[b, i+k, j+k],
            or None when either sequence is shorter than n.
    """
    cand_len, ref_len = A.size()[1:]
    # Number of n-gram start positions along each axis. Each axis is sliced
    # with its own length so that non-square inputs yield the documented
    # output shape instead of being cut to min(cand_len, ref_len) on both axes.
    rows = cand_len - (n - 1)
    cols = ref_len - (n - 1)
    if min(rows, cols) <= 0:
        return None
    cur = A[:, :rows, :cols].clone()
    for i in range(1, n):
        # Shift one step down the diagonal and multiply in the next word.
        cur = cur * A[:, i:i + rows, i:i + cols]
    return cur
def get_selected_matrices(probs, references, dim=1):
    """
    Batched index select.

    probs - batch of matrices, iterated along its first dimension
    references - per-example index sequences (one per element of probs)
    dim - dimension of each batch element to select along

    Output: a single tensor with the per-example selections stacked along a
            new leading batch dimension.
    """
    # NOTE: index_select is applied in a Python loop, one example at a time,
    # because torch has no batched index_select.
    selected = [
        torch.index_select(
            a, dim, torch.as_tensor(i, dtype=torch.long, device=a.device))
        for a, i in zip(probs, references)
    ]
    # The original returned a bare list (the concatenating call was lost,
    # leaving an unbalanced ")"); callers need one tensor they can slice.
    return torch.stack(selected)
def ngram_ref_counts(reference, lengths, n):
    """
    For each start position, count the n-grams equal to the n-gram at that
    position. Only the first appearance of an n-gram carries its count;
    later appearances of the same n-gram get 0.

    reference - matrix of sequences of vocabulary ids [batch, ref len]
        NOTE reference should be padded with some special ids
    lengths - length of each reference in the batch
        At least one value in lengths must be equal to reference.shape[1]
    n - n-gram order

    Output: float tensor [batch, max(lengths) - n + 1] with the counts for
            each start position, padded with zeros; None when even the
            longest reference is shorter than n. The tensor is placed on the
            GPU when CUDA is available, like modules.utils.FloatTensor.
    """
    max_len = int(max(lengths))
    n_positions = max_len - n + 1
    if n_positions <= 0:
        return None
    res = []
    for r, l in zip(reference, lengths):
        # Plain Python ints hash by value; tensor elements hash by identity
        # and would make every n-gram look unique.
        tokens = r.tolist() if hasattr(r, 'tolist') else r
        current_length = int(l) - n + 1
        n_grams = [tuple(tokens[i + j] for j in range(n))
                   for i in range(current_length)]
        cnt = Counter(n_grams)
        picked = set()  # n-grams whose count has already been emitted
        occurrence = []
        for n_gram in n_grams:
            if n_gram in picked:
                occurrence.append(0)
            else:
                picked.add(n_gram)
                occurrence.append(cnt[n_gram])
        # Zero-pad so that every row has exactly n_positions entries.
        occurrence.extend([0] * (n_positions - len(occurrence)))
        res.append(occurrence)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.tensor(res, dtype=torch.float, device=device)
def calculate_overlap(p, r, n, lengths):
    """
    Expected clipped n-gram overlap between candidates and references.

    p - probability tensor [b x len_x x reference_length]
        (shape as stated by the original author; its last dimension is
        indexed by the ids in r via get_selected_matrices)
    r - references, tensor [b x len_y]
        contains word ids for each reference in the batch
    n - n-gram order
    lengths - lengths of each reference in the batch

    Output: tensor [b] with one overlap value per batch element; all zeros
            when the sequences are too short to contain any n-gram.
    """
    A = ngrams_product(get_selected_matrices(p, r), n)
    r_cnt = ngram_ref_counts(r, lengths, n)
    if A is None or r_cnt is None:
        return CUDA_wrapper(torch.zeros(p.shape[0]))
    # Insert a candidate-position axis so the counts broadcast against A.
    r_cnt = r_cnt[:, None]
    # 1 + (sum of A over candidate positions) - A.
    A_div = -A + torch.sum(A, 1, keepdim=True) + 1
    second_arg = r_cnt / A_div
    # Clip: each entry contributes at most A itself.
    term = torch.min(A, A * second_arg)
    return torch.sum(torch.sum(term, 2), 1)
def bleu(p, r, translation_lengths, reference_lengths, max_order=4, smooth=False):
"""
Compute a negated BLEU-style score from expected n-gram overlaps and return
it together with the per-order precisions.

p - matrix with probabilities
r - reference batch
reference_lengths - lengths of the references
max_order - max order of n-gram
smooth - smooth calculation of precisions
translation_lengths - torch tensor

Returns (bleu, precisions), where bleu = -geo_mean * bp.
"""
# NOTE(review): this block appears to have lost lines during extraction
# (several `else:` branches and one divisor are missing); the comments
# below mark the affected places - TODO restore from upstream before use.
overlaps_list = []
translation_length = sum(translation_lengths)
reference_length = sum(reference_lengths)
# One overlap vector (one value per batch element) for every order 1..max_order.
for n in range(1, max_order + 1):
overlaps_list.append(calculate_overlap(p, r, n, reference_lengths))
overlaps = torch.stack(overlaps_list)
# Sum over the batch: one match total per n-gram order.
matches_by_order = torch.sum(overlaps, 1)
possible_matches_by_order = torch.zeros(max_order)
for n in range(1, max_order + 1):
# n-gram start positions per translation, floored at zero via the mask.
cur_pm = translation_lengths.float() - n + 1
mask = cur_pm > 0
cur_pm *= mask.float()
possible_matches_by_order[n - 1] = torch.sum(cur_pm)
precisions = Variable(FloatTensor([0] * max_order))
for i in range(max_order):
if smooth:
# Add-one smoothing of numerator and denominator.
precisions[i] = (matches_by_order[i] + 1) /\
(possible_matches_by_order[i] + 1)
# NOTE(review): presumably an `else:` preceded the next test, and another
# preceded the zero assignment; the line ending in "/\" has lost its divisor.
if possible_matches_by_order[i] > 0:
precisions[i] = matches_by_order[i] /\
precisions[i] = Variable(FloatTensor([0]))
# Geometric mean via logs when the smallest precision is positive.
if torch.min(precisions[:max_order]).data[0] > 0:
p_log_sum = sum([(1. / max_order) * torch.log(p) for p in precisions])
geo_mean = torch.exp(p_log_sum)
# NOTE(review): presumably the direct-product form and the warning below
# belonged to a lost `else:` branch - TODO confirm.
geo_mean = torch.pow(\
reduce(lambda x, y: x*y, precisions), 1./max_order)
eprint('WARNING: some precision(s) is zero')
# Brevity penalty from the ratio of total translation to total reference length.
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.0
# NOTE(review): as written, the three assignments below run in sequence and
# leave bp == MIN_BP; the branching that chose between them looks lost.
MIN_BP = 1E-2
bp = np.exp(1 - 1. / ratio)
bp = MIN_BP
# Negated so that the value can be minimised as a loss.
bleu = -geo_mean * bp
return bleu, precisions

@ -0,0 +1,114 @@
import torch
from torch.nn import functional
from torch.autograd import Variable
import numpy as np
import os
from functools import reduce
from copy import deepcopy as copy
import time
from modules.utils import CUDA_wrapper
from modules.utils import SoftmaxWithTemperature
from modules.utils import fill_eye_diag
# NOTE(review): this class appears heavily truncated by extraction: the
# __init__ signature is cut off, several right-hand sides are empty and some
# statements are reduced to fragments. The comments below describe only what
# the visible tokens show - TODO restore from upstream before use.
class mBLEU:
# NOTE(review): the parameter list ends in a line continuation and never
# closes; `std_temp`, which is read below, is not among the visible parameters.
def __init__(self, max_order=4, softmax_temperature=0.001, T_argmax=True,\
"""class implementing straightforwad matrix BLEU computation"""
self.max_order = max_order
# NOTE(review): as written, this chained assignment binds the
# SoftmaxWithTemperature instance to both self.T_argmax and T_argmax;
# presumably two separate statements were merged - TODO confirm.
self.T_argmax = T_argmax = SoftmaxWithTemperature(softmax_temperature)
self.softmax_regular = torch.nn.Softmax()
self.std_temp = std_temp
def __call__(self, R, T, reference_corpus_lens, translation_corpus_lens):
"""
T[b x t x v]
R[b x r]
reference_corpus_lens - list, len=b
translation_corpus_lens - list, len=b

Returns (bleu, precisions), where bleu = -geo_mean * bp.
"""
max_order = self.max_order
# NOTE(review): both right-hand sides are missing; shapeT[1], shapeT[2] and
# shapeR[1] are read further down.
shapeR =
shapeT =
translation_length = sum(translation_corpus_lens)
reference_length = sum(reference_corpus_lens)
if self.T_argmax:
cur_temperature = None
if self.std_temp:
cur_temperature = T.std()
if (np.random.rand(1)[0] > 0.99):
# NOTE(review): the next statement is truncated and cannot be interpreted.
T =, shapeT[2]),\
# Batched products: T with the transposed R, and T with its own transpose.
TR = T.bmm(R.transpose(1, 2))
TT = T.bmm(T.transpose(1, 2))
# TT = fill_eye_diag(TT)
# reference_len and tanslation_len are not read again in the visible code;
# reference_length and translation_length above are the ones used.
reference_len = sum(reference_corpus_lens)
tanslation_len = sum(translation_corpus_lens)
matches_by_order = [CUDA_wrapper(Variable(torch.FloatTensor([0])))\
for i in range(max_order)]
cur_t = TT
cur_tr = TR
all_t = [torch.sum(cur_t, 1)]
all_tr = [torch.sum(cur_tr, 2)]
# NOTE(review): SMOOTH_CONST is not defined anywhere in the visible code.
def overlapper(t, tr):
return torch.sum((torch.min(t, tr) + SMOOTH_CONST) / torch.max(\
(t + SMOOTH_CONST),CUDA_wrapper(Variable(\
torch.FloatTensor([1])))), 1)
overlap = overlapper(all_t[-1], all_tr[-1])
matches_by_order[0] = torch.sum(overlap)
# NOTE(review): the list comprehension below has lost its element expression
# and its closing bracket.
possible_matches_by_order = [
for i in range(max_order)\
def update_possible_matches(possible_matches_by_order,\
translation_corpus_lens, order):
for transl_len in translation_corpus_lens:
possible_matches = transl_len - order
if possible_matches > 0:
possible_matches_by_order[order] += possible_matches
# NOTE(review): the next line looks like the tail of a call to
# update_possible_matches for order 0; the head of the call is missing.
translation_corpus_lens, 0)
# Higher orders: multiply each matrix by itself shifted one step along the diagonal.
for order in range(1, min(max_order, shapeT[1], shapeR[1])):
cur_t = TT[:, order:, order:] * cur_t[:, :-1, :-1]
all_t.append(torch.sum(cur_t, 1))
cur_tr = TR[:, order:, order:] * cur_tr[:, :-1, :-1]
all_tr.append(torch.sum(cur_tr, 2))
overlap = overlapper(all_t[-1], all_tr[-1])
matches_by_order[order] = torch.sum(overlap)
# NOTE(review): another call tail whose head is missing.
translation_corpus_lens, order)
precisions = [CUDA_wrapper(Variable(torch.FloatTensor([0])))\
for i in range(max_order)]
# NOTE(review): the branches in the loop below have lost their `else:` lines,
# and the last division has lost its divisor.
for i in range(0, max_order):
if possible_matches_by_order[i].data[0] > 0:
if i > 0:
precisions[i] = ((matches_by_order[i].float() + 1)\
/( possible_matches_by_order[i] + 1))
precisions[i] = (matches_by_order[i].float()\
precisions[i] = CUDA_wrapper(Variable(torch.FloatTensor([0])))
# Geometric mean via logs when the smallest precision exceeds 1E-3.
if torch.min(torch.stack(precisions)).data[0] > 1E-3:
p_log_sum = sum([(1. / max_order) * torch.log(p)\
for p in precisions])
geo_mean = torch.exp(p_log_sum)
# NOTE(review): presumably this direct-product form belonged to a lost `else:` branch.
geo_mean = torch.pow(\
reduce(lambda x, y: x*y, precisions), 1./max_order)
# Brevity penalty from the ratio of total translation to total reference length.
ratio = float(translation_length) / reference_length
if ratio > 1.0:
bp = 1.
# NOTE(review): as written, the three assignments below run in sequence and
# leave bp == MIN_BP; the branching that chose between them looks lost.
MIN_BP = 1E-2
bp = np.exp(1 - 1. / ratio)
bp = MIN_BP
# Negated so that the value can be minimised as a loss.
bleu = -geo_mean * bp
return bleu, precisions

@ -0,0 +1,39 @@
import torch
# Tensor constructor aliases: CUDA tensor types when a GPU is available,
# CPU tensor types otherwise. Without the `else:` the CPU aliases would
# always overwrite the CUDA ones.
if torch.cuda.is_available():
    Tensor = torch.cuda.FloatTensor
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
    ByteTensor = torch.cuda.ByteTensor
else:
    Tensor = torch.Tensor
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
    ByteTensor = torch.ByteTensor
def CUDA_wrapper(tensor):
    """Return ``tensor.cuda()`` when CUDA is available, else ``tensor`` itself."""
    return tensor.cuda() if torch.cuda.is_available() else tensor
class SoftmaxWithTemperature:
    """Softmax applied to temperature-scaled input."""

    def __init__(self, temperature):
        """
        formula: softmax(x/temperature)

        temperature - divisor used when a call does not supply its own
        """
        self.temperature = temperature
        # No dim is passed, so torch chooses the softmax dimension implicitly.
        self.softmax = torch.nn.Softmax()

    def __call__(self, x, temperature=None):
        """Return softmax(x / temperature).

        x - input tensor
        temperature - optional per-call override; when None, the value
                      given to the constructor is used
        """
        if temperature is None:
            temperature = self.temperature
        return self.softmax(x / temperature)
def fill_eye_diag(a):
    """Return ``a`` with the main diagonal of every matrix replaced by ones.

    a - batch of matrices [b x s1 x s2]

    Off-diagonal entries are kept as they are; the input is not modified.
    """
    # The original unpacking had lost its right-hand side; the three targets
    # match the batch-of-matrices shape of `a`.
    _, s1, s2 = a.shape
    # Built directly on a's device, so this needs neither Variable (which
    # this module never imports) nor a separate CUDA transfer.
    dd = torch.eye(s1, s2, device=a.device)
    zero_dd = 1 - dd
    return a * zero_dd + dd

@ -132,7 +132,7 @@ def get_learning_rate(i, args):
def step(model, batch, opt, iteration, field, task, lr=None, grad_clip=None, writer=None, it=None):
loss, predictions = model(batch)
loss, predictions = model(batch, iteration)
if lr is not None:
opt.param_groups[0]['lr'] = lr
@ -182,7 +182,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
(args.load and iteration == start_iteration + 1))):
train_task_val_metric = None
for val_task_idx, (val_task, val_iter) in enumerate(val_iters):
val_loss, metric_dict = validate(val_task, val_iter, model, logger, field, world_size, rank, num_print=args.num_print, args=args)
val_loss, metric_dict = validate(val_task, val_iter, model, logger, field, world_size, rank, iteration, num_print=args.num_print, args=args)
if val_loss is not None:
log_entry = f'{args.timestamp}:{elapsed_time(logger)}:iteration_{iteration}:{round_progress}train_{task}:{task_progress}val_{val_task}:val_loss{val_loss.item():.4f}:'
writer.add_scalars(f'loss/val', {val_task: val_loss.item()}, iteration)

@ -4,11 +4,11 @@ from metrics import compute_metrics
from import get_tokenizer
def compute_validation_outputs(model, val_iter, field, optional_names=[]):
def compute_validation_outputs(model, val_iter, field, iteration, optional_names=[]):
loss, predictions, answers = [], [], []
outputs = [[] for _ in range(len(optional_names))]
for batch_idx, batch in enumerate(val_iter):
l, p = model(batch)
l, p = model(batch, iteration)
predictions.append(pad(p, 150, dim=-1, val=field.vocab.stoi['<pad>']))
a = None
@ -54,8 +54,8 @@ def all_reverse(tensor, world_size, task, field, clip, dim=0):
return field.reverse(tensor)[:clip]
def gather_results(model, val_iter, field, world_size, task, optional_names=[]):
loss, predictions, answers, outputs = compute_validation_outputs(model, val_iter, field, optional_names=optional_names)
def gather_results(model, val_iter, field, world_size, task, iteration, optional_names=[]):
loss, predictions, answers, outputs = compute_validation_outputs(model, val_iter, field, iteration, optional_names=optional_names)
clip = get_clip(val_iter)
if not hasattr(val_iter.dataset.examples[0], 'squad_id') and not hasattr(val_iter.dataset.examples[0], 'wikisql_id') and not hasattr(val_iter.dataset.examples[0], 'woz_id'):
answers = all_reverse(answers, world_size, task, field, clip)
@ -75,12 +75,12 @@ def print_results(keys, values, rank=None, num_print=1):
def validate(task, val_iter, model, logger, field, world_size, rank, num_print=10, args=None):
def validate(task, val_iter, model, logger, field, world_size, rank, iteration, num_print=10, args=None):
with torch.no_grad():
required_names = ['greedy', 'answer']
optional_names = ['context', 'question']
loss, predictions, answers, results = gather_results(model, val_iter, field, world_size, task, optional_names=optional_names)
loss, predictions, answers, results = gather_results(model, val_iter, field, world_size, task, iteration, optional_names=optional_names)
predictions = [p.replace('UNK', 'OOV') for p in predictions]
names = required_names + optional_names
if hasattr(val_iter.dataset.examples[0], 'wikisql_id') or hasattr(val_iter.dataset.examples[0], 'squad_id') or hasattr(val_iter.dataset.examples[0], 'woz_id'):