rm Variable from sapg

commit 4a962fa2bc
parent ee3ba6c1b5
@@ -5,7 +5,6 @@ import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.autograd import Variable
 
 from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask
 
@@ -121,10 +120,10 @@ class SelfAttentivePointerGenerator(nn.Module):
         effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx)
         if self.generative_vocab_size < effective_vocab_size:
             size[-1] = effective_vocab_size - self.generative_vocab_size
-            buff = Variable(scaled_p_vocab.data.new(*size).fill_(EPSILON))
+            buff = scaled_p_vocab.new_full(size, EPSILON)
             scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1)
 
-        p_context_ptr = Variable(scaled_p_vocab.data.new(*scaled_p_vocab.size()).fill_(EPSILON))
+        p_context_ptr = scaled_p_vocab.new_full(scaled_p_vocab.size(), EPSILON)
         p_context_ptr.scatter_add_(p_context_ptr.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention), context_attention)
         scaled_p_context_ptr = (1 - vocab_pointer_switches).expand_as(p_context_ptr) * p_context_ptr
 
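This hunk shows the core pattern of the commit: the pre-0.4 idiom `Variable(x.data.new(*size).fill_(v))` becomes `x.new_full(size, v)`, which allocates a tensor on the same device and with the same dtype as `x` in one call. A minimal sketch of the equivalence; the EPSILON value and shapes below are made-up stand-ins for the model's values:

import torch

EPSILON = 1e-10                    # stand-in for the constant imported from .common
scaled_p_vocab = torch.rand(2, 4)  # stand-in for the model's probability tensor
size = [2, 3]

# old: buff = Variable(scaled_p_vocab.data.new(*size).fill_(EPSILON))
buff = scaled_p_vocab.new_full(size, EPSILON)

# new_full inherits device and dtype from scaled_p_vocab
assert buff.dtype == scaled_p_vocab.dtype and buff.device == scaled_p_vocab.device
assert torch.allclose(buff, torch.full((2, 3), EPSILON))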
@@ -135,9 +134,8 @@ class SelfAttentivePointerGenerator(nn.Module):
     def greedy(self, self_attended_context, context, context_indices, oov_to_limited_idx, rnn_state=None):
         B, TC, C = context.size()
         T = self.args.max_output_length
-        outs = Variable(context.data.new(B, T).long().fill_(
-            self.field.decoder_stoi['<pad>']), volatile=True)
-        hiddens = [Variable(self_attended_context[0].data.new(B, T, C).zero_(), volatile=True)
+        outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
+        hiddens = [self_attended_context[0].new_zeros((B, T, C))
                    for l in range(len(self.self_attentive_decoder.layers) + 1)]
         hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
         eos_yet = context.data.new(B).byte().zero_()
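Two more replacements appear here: `new_full` can override the source tensor's dtype via the `dtype` keyword, and `new_zeros` replaces the `new(...).zero_()` dance. The `volatile=True` flag no longer exists in PyTorch >= 0.4; inference is instead disabled with `torch.no_grad()`. A sketch with made-up sizes and a placeholder pad index:

import torch

B, T, C = 2, 5, 8        # made-up batch, time, and feature sizes
pad_idx = 0              # stand-in for self.field.decoder_stoi['<pad>']
context = torch.rand(B, 7, C)

# old: Variable(context.data.new(B, T).long().fill_(pad_idx), volatile=True)
outs = context.new_full((B, T), pad_idx, dtype=torch.long)

# old: Variable(context.data.new(B, T, C).zero_(), volatile=True)
hidden = context.new_zeros((B, T, C))

# volatile=True is gone; disable autograd for inference explicitly instead
with torch.no_grad():
    _ = hidden.sum()     # no graph is recorded inside this block

assert outs.dtype == torch.long and hidden.shape == (B, T, C)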
@@ -145,9 +143,8 @@ class SelfAttentivePointerGenerator(nn.Module):
         rnn_output, context_alignment = None, None
         for t in range(T):
             if t == 0:
-                embedding = self.decoder_embeddings(Variable(
-                    self_attended_context[-1].data.new(B).long().fill_(
-                        self.field.vocab.stoi['<init>']), volatile=True).unsqueeze(1), [1]*B)
+                embedding = self.decoder_embeddings(
+                    self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long), [1]*B)
             else:
                 embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B)
             hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
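Besides dropping `Variable`, this hunk folds the trailing `.unsqueeze(1)` into the allocation: rather than building a length-B vector of `<init>` indices and then adding a dimension, `new_full` produces the (B, 1) column of start tokens directly. A sketch, with the index value as a stand-in:

import torch

B = 2
init_idx = 3               # stand-in for self.field.vocab.stoi['<init>']
x = torch.rand(B, 7, 8)    # stand-in for self_attended_context[-1]

# old: Variable(x.data.new(B).long().fill_(init_idx), volatile=True).unsqueeze(1)
start = x.new_full((B, 1), init_idx, dtype=torch.long)

assert start.shape == (B, 1) and bool((start == init_idx).all())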
@@ -167,52 +164,13 @@ class SelfAttentivePointerGenerator(nn.Module):
                                       context_indices,
                                       oov_to_limited_idx)
             pred_probs, preds = probs.max(-1)
-            eos_yet = eos_yet | (preds.data == self.field.decoder_stoi['<eos>'])
-            outs[:, t] = Variable(preds.data.cpu().apply_(self.map_to_full), volatile=True)
+            preds = preds.squeeze(1)
+            eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
+            outs[:, t] = preds.cpu().apply_(self.map_to_full)
             if eos_yet.all():
                 break
         return outs
 
-
-class CoattentiveLayer(nn.Module):
-
-    def __init__(self, d, dropout=0.2):
-        super().__init__()
-        self.proj = Feedforward(d, d, dropout=0.0)
-        self.embed_sentinel = nn.Embedding(2, d)
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, context, question, context_padding, question_padding):
-        context_padding = torch.cat([context.data.new(context.size(0)).long().fill_(0).unsqueeze(1).long()==1, context_padding], 1)
-        question_padding = torch.cat([question.data.new(question.size(0)).long().fill_(0).unsqueeze(1)==1, question_padding], 1)
-
-        context_sentinel = self.embed_sentinel(Variable(context.data.new(context.size(0)).long().fill_(0)))
-        context = torch.cat([context_sentinel.unsqueeze(1), self.dropout(context)], 1) # batch_size x (context_length + 1) x features
-
-        question_sentinel = self.embed_sentinel(Variable(question.data.new(question.size(0)).long().fill_(1)))
-        question = torch.cat([question_sentinel.unsqueeze(1), question], 1) # batch_size x (question_length + 1) x features
-        question = F.tanh(self.proj(question)) # batch_size x (question_length + 1) x features
-
-        affinity = context.bmm(question.transpose(1,2)) # batch_size x (context_length + 1) x (question_length + 1)
-        attn_over_context = self.normalize(affinity, context_padding) # batch_size x (context_length + 1) x 1
-        attn_over_question = self.normalize(affinity.transpose(1,2), question_padding) # batch_size x (question_length + 1) x 1
-        sum_of_context = self.attn(attn_over_context, context) # batch_size x (question_length + 1) x features
-        sum_of_question = self.attn(attn_over_question, question) # batch_size x (context_length + 1) x features
-        coattn_context = self.attn(attn_over_question, sum_of_context) # batch_size x (context_length + 1) x features
-        return torch.cat([coattn_context, sum_of_question], 2)[:, 1:]
-
-    @staticmethod
-    def attn(weights, candidates):
-        w1, w2, w3 = weights.size()
-        c1, c2, c3 = candidates.size()
-        return weights.unsqueeze(3).expand(w1, w2, w3, c3).mul(candidates.unsqueeze(2).expand(c1, c2, w3, c3)).sum(1).squeeze(1)
-
-    @staticmethod
-    def normalize(original, padding):
-        raw_scores = original.clone()
-        raw_scores.data.masked_fill_(padding.unsqueeze(-1).expand_as(raw_scores), -INF)
-        return F.softmax(raw_scores, dim=1)
-
 class DualPtrRNNDecoder(nn.Module):
 
     def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1):
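The last hunk touches tensor semantics rather than just allocation: with `Variable` and `Tensor` merged, `preds` can be compared against the `<eos>` index directly, without the `.data` detour; the new `squeeze(1)` collapses the per-step time dimension so `eos_yet` stays one flag per batch element; and `Tensor.apply_`, which maps a Python function over elements in place, only works on CPU tensors, hence the `.cpu()` before remapping to the full vocabulary. (The unused `CoattentiveLayer` class is deleted outright in this hunk rather than ported.) A sketch with stand-in values:

import torch

eos_idx = 2

def map_to_full(i):                # stand-in for self.map_to_full
    return i + 10

probs = torch.rand(2, 1, 5)        # (B, 1, vocab): a single decoding step
pred_probs, preds = probs.max(-1)  # both (B, 1)
preds = preds.squeeze(1)           # (B,)

# no .data needed: tensors compare directly after the 0.4 merge
eos_yet = preds == eos_idx         # (B,) bool

# apply_ is in-place, element-wise, and CPU-only
full_ids = preds.cpu().apply_(map_to_full)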