rm Variable from sapg

Bryan Marcus McCann 2018-09-18 00:41:45 +00:00 committed by Bryan McCann
parent ee3ba6c1b5
commit 4a962fa2bc
1 changed file with 9 additions and 51 deletions


@@ -5,7 +5,6 @@ import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.autograd import Variable
 from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask
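Background note (not part of the commit): PyTorch 0.4 merged `Variable` into `Tensor`, so the import above is dead weight and autograd works directly on plain tensors. A minimal sketch of the equivalence:

```python
import torch

# Pre-0.4 idiom that this commit removes:
#   from torch.autograd import Variable
#   x = Variable(torch.randn(3, 3), requires_grad=True)
# Since the 0.4 Tensor/Variable merge, tensors carry autograd state themselves:
x = torch.randn(3, 3, requires_grad=True)
loss = (x * 2).sum()
loss.backward()
print(x.grad.shape)  # torch.Size([3, 3])
```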
@@ -121,10 +120,10 @@ class SelfAttentivePointerGenerator(nn.Module):
         effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx)
         if self.generative_vocab_size < effective_vocab_size:
             size[-1] = effective_vocab_size - self.generative_vocab_size
-            buff = Variable(scaled_p_vocab.data.new(*size).fill_(EPSILON))
+            buff = scaled_p_vocab.new_full(size, EPSILON)
             scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1)
-        p_context_ptr = Variable(scaled_p_vocab.data.new(*scaled_p_vocab.size()).fill_(EPSILON))
+        p_context_ptr = scaled_p_vocab.new_full(scaled_p_vocab.size(), EPSILON)
         p_context_ptr.scatter_add_(p_context_ptr.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention), context_attention)
         scaled_p_context_ptr = (1 - vocab_pointer_switches).expand_as(p_context_ptr) * p_context_ptr
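Aside, a hedged restatement of the pattern in this hunk: `tensor.new_full(size, value)` allocates on the same device and with the same dtype as `tensor`, which is exactly what the old `Variable(tensor.data.new(*size).fill_(value))` spelling achieved. The shapes and `EPSILON` value below are placeholders, not the model's real ones:

```python
import torch

EPSILON = 1e-10                       # placeholder; the real constant comes from .common
scaled_p_vocab = torch.rand(4, 7)     # stand-in for the generative distribution
size = [4, 3]                         # extra columns reserved for OOV tokens

# Old: buff = Variable(scaled_p_vocab.data.new(*size).fill_(EPSILON))
buff = scaled_p_vocab.new_full(size, EPSILON)              # same dtype/device, no wrapper
scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)
print(scaled_p_vocab.shape)           # torch.Size([4, 10])
```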
@@ -135,9 +134,8 @@ class SelfAttentivePointerGenerator(nn.Module):
     def greedy(self, self_attended_context, context, context_indices, oov_to_limited_idx, rnn_state=None):
         B, TC, C = context.size()
         T = self.args.max_output_length
-        outs = Variable(context.data.new(B, T).long().fill_(
-            self.field.decoder_stoi['<pad>']), volatile=True)
-        hiddens = [Variable(self_attended_context[0].data.new(B, T, C).zero_(), volatile=True)
+        outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
+        hiddens = [self_attended_context[0].new_zeros((B, T, C))
                    for l in range(len(self.self_attentive_decoder.layers) + 1)]
         hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
         eos_yet = context.data.new(B).byte().zero_()
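For the same reason, `new_full` with an explicit `dtype` and `new_zeros` replace the `.data.new(...).long().fill_(...)` / `.zero_()` chains here; the `volatile=True` flag is simply dropped (see the note after the next hunk). A sketch with illustrative shapes and a hypothetical `<pad>` index:

```python
import torch

B, T, C = 2, 4, 8
context = torch.rand(B, 5, C)         # (B, TC, C), illustrative shapes
pad_idx = 1                           # hypothetical stand-in for decoder_stoi['<pad>']

# Old: Variable(context.data.new(B, T).long().fill_(pad_idx), volatile=True)
outs = context.new_full((B, T), pad_idx, dtype=torch.long)

# Old: Variable(x.data.new(B, T, C).zero_(), volatile=True)
hiddens = [context.new_zeros((B, T, C)) for _ in range(3)]

print(outs.dtype, hiddens[0].shape)   # torch.int64 torch.Size([2, 4, 8])
```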
@@ -145,9 +143,8 @@ class SelfAttentivePointerGenerator(nn.Module):
         rnn_output, context_alignment = None, None
         for t in range(T):
             if t == 0:
-                embedding = self.decoder_embeddings(Variable(
-                    self_attended_context[-1].data.new(B).long().fill_(
-                        self.field.vocab.stoi['<init>']), volatile=True).unsqueeze(1), [1]*B)
+                embedding = self.decoder_embeddings(
+                    self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long), [1]*B)
             else:
                 embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B)
             hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
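The dropped `volatile=True` has no per-tensor replacement; since PyTorch 0.4 the equivalent is to run inference under `torch.no_grad()`. This is a general migration note, not something the commit adds:

```python
import torch
from torch import nn

model = nn.Linear(8, 8)
x = torch.rand(2, 8)

# Old: inputs were wrapped as Variable(..., volatile=True) to skip autograd bookkeeping.
# Since 0.4, disable graph construction around the whole decoding loop instead:
with torch.no_grad():
    y = model(x)
print(y.requires_grad)  # False
```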
@@ -167,52 +164,13 @@ class SelfAttentivePointerGenerator(nn.Module):
                 context_indices,
                 oov_to_limited_idx)
             pred_probs, preds = probs.max(-1)
-            eos_yet = eos_yet | (preds.data == self.field.decoder_stoi['<eos>'])
-            outs[:, t] = Variable(preds.data.cpu().apply_(self.map_to_full), volatile=True)
+            preds = preds.squeeze(1)
+            eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
+            outs[:, t] = preds.cpu().apply_(self.map_to_full)
             if eos_yet.all():
                 break
         return outs
-class CoattentiveLayer(nn.Module):
-    def __init__(self, d, dropout=0.2):
-        super().__init__()
-        self.proj = Feedforward(d, d, dropout=0.0)
-        self.embed_sentinel = nn.Embedding(2, d)
-        self.dropout = nn.Dropout(dropout)
-    def forward(self, context, question, context_padding, question_padding):
-        context_padding = torch.cat([context.data.new(context.size(0)).long().fill_(0).unsqueeze(1).long()==1, context_padding], 1)
-        question_padding = torch.cat([question.data.new(question.size(0)).long().fill_(0).unsqueeze(1)==1, question_padding], 1)
-        context_sentinel = self.embed_sentinel(Variable(context.data.new(context.size(0)).long().fill_(0)))
-        context = torch.cat([context_sentinel.unsqueeze(1), self.dropout(context)], 1) # batch_size x (context_length + 1) x features
-        question_sentinel = self.embed_sentinel(Variable(question.data.new(question.size(0)).long().fill_(1)))
-        question = torch.cat([question_sentinel.unsqueeze(1), question], 1) # batch_size x (question_length + 1) x features
-        question = F.tanh(self.proj(question)) # batch_size x (question_length + 1) x features
-        affinity = context.bmm(question.transpose(1,2)) # batch_size x (context_length + 1) x (question_length + 1)
-        attn_over_context = self.normalize(affinity, context_padding) # batch_size x (context_length + 1) x 1
-        attn_over_question = self.normalize(affinity.transpose(1,2), question_padding) # batch_size x (question_length + 1) x 1
-        sum_of_context = self.attn(attn_over_context, context) # batch_size x (question_length + 1) x features
-        sum_of_question = self.attn(attn_over_question, question) # batch_size x (context_length + 1) x features
-        coattn_context = self.attn(attn_over_question, sum_of_context) # batch_size x (context_length + 1) x features
-        return torch.cat([coattn_context, sum_of_question], 2)[:, 1:]
-    @staticmethod
-    def attn(weights, candidates):
-        w1, w2, w3 = weights.size()
-        c1, c2, c3 = candidates.size()
-        return weights.unsqueeze(3).expand(w1, w2, w3, c3).mul(candidates.unsqueeze(2).expand(c1, c2, w3, c3)).sum(1).squeeze(1)
-    @staticmethod
-    def normalize(original, padding):
-        raw_scores = original.clone()
-        raw_scores.data.masked_fill_(padding.unsqueeze(-1).expand_as(raw_scores), -INF)
-        return F.softmax(raw_scores, dim=1)
 class DualPtrRNNDecoder(nn.Module):
     def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1):
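One last detail from the greedy loop above, restated as a hedged sketch (the indices and the mapping function are hypothetical): the new code squeezes `preds` to shape `(B,)` before the `<eos>` check, and keeps the `.cpu()` call because `Tensor.apply_` runs a Python callable elementwise and only works on CPU tensors:

```python
import torch

decoder_stoi = {'<pad>': 0, '<eos>': 2}   # hypothetical vocabulary indices

def map_to_full(i):
    """Hypothetical limited-to-full vocabulary index mapping."""
    return i + 100

probs = torch.rand(3, 1, 5)               # (B, 1, V), illustrative
eos_yet = torch.zeros(3, dtype=torch.bool)

pred_probs, preds = probs.max(-1)
preds = preds.squeeze(1)                  # (B, 1) -> (B,)
eos_yet = eos_yet | (preds == decoder_stoi['<eos>'])
# apply_ is CPU-only and in-place, hence the explicit .cpu() before mapping.
mapped = preds.cpu().apply_(map_to_full)
print(mapped, eos_yet)
```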