rm Variable from pg
parent 59a1dcd47d
commit ee3ba6c1b5
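The change throughout is mechanical: pre-0.4 PyTorch wrapped tensors in
torch.autograd.Variable and used volatile=True to disable autograd at
inference time; since 0.4, Variable is merged into Tensor, volatile is a
no-op, and the same effect comes from torch.no_grad() plus the new_full /
new_zeros factory methods. A minimal standalone sketch of the pattern
(illustration only, not part of the diff):

    import torch

    x = torch.zeros(2, 3)

    # old idiom (deprecated): Variable(x.data.new(2, 3).fill_(1e-10), volatile=True)
    # new idiom: factory methods inherit dtype and device from the source tensor
    buff = x.new_full((2, 3), 1e-10)
    pad = x.new_full((2, 3), 0, dtype=torch.long)  # dtype can be overridden

    with torch.no_grad():  # replaces volatile=True for inference
        y = buff * 2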
@@ -172,44 +172,43 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
     def greedy(self, self_attended_context, context, question, context_indices, question_indices, oov_to_limited_idx, rnn_state=None):
-        B, TC, C = context.size()
-        T = self.args.max_output_length
-        outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
-        hiddens = [self_attended_context[0].new_zeros((B, T, C))
-                   for l in range(len(self.self_attentive_decoder.layers) + 1)]
-        hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
-        eos_yet = context.new_zeros((B, )).byte()
-
-        rnn_output, context_alignment, question_alignment = None, None, None
-        for t in range(T):
-            if t == 0:
-                embedding = self.decoder_embeddings(
-                    self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long), [1]*B)
-            else:
-                embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B)
-            hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
-            for l in range(len(self.self_attentive_decoder.layers)):
-                hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
-                    self.self_attentive_decoder.layers[l].attention(
-                        self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], hiddens[l][:, :t + 1])
-                    , self_attended_context[l], self_attended_context[l]))
-            decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1),
-                                                        context, question,
-                                                        context_alignment=context_alignment, question_alignment=question_alignment,
-                                                        hidden=rnn_state, output=rnn_output)
-            rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
-            probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
-                               context_attention, question_attention,
-                               context_indices, question_indices,
-                               oov_to_limited_idx)
-            pred_probs, preds = probs.max(-1)
-            preds = preds.squeeze(1)
-            eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
-            outs[:, t] = preds.cpu().apply_(self.map_to_full)
-            if eos_yet.all():
-                break
-        return outs
+        with torch.no_grad():
+            B, TC, C = context.size()
+            T = self.args.max_output_length
+            outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
+            hiddens = [self_attended_context[0].new_zeros((B, T, C))
+                       for l in range(len(self.self_attentive_decoder.layers) + 1)]
+            hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
+            eos_yet = context.new_zeros((B, )).byte()
+
+            rnn_output, context_alignment, question_alignment = None, None, None
+            for t in range(T):
+                if t == 0:
+                    embedding = self.decoder_embeddings(
+                        self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long), [1]*B)
+                else:
+                    embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B)
+                hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
+                for l in range(len(self.self_attentive_decoder.layers)):
+                    hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
+                        self.self_attentive_decoder.layers[l].attention(
+                            self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], hiddens[l][:, :t + 1])
+                        , self_attended_context[l], self_attended_context[l]))
+                decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1),
+                                                            context, question,
+                                                            context_alignment=context_alignment, question_alignment=question_alignment,
+                                                            hidden=rnn_state, output=rnn_output)
+                rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
+                probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
+                                   context_attention, question_attention,
+                                   context_indices, question_indices,
+                                   oov_to_limited_idx)
+                pred_probs, preds = probs.max(-1)
+                preds = preds.squeeze(1)
+                eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
+                outs[:, t] = preds.cpu().apply_(self.map_to_full)
+                if eos_yet.all():
+                    break
+            return outs
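The eos_yet mask above implements sticky early stopping: a sequence is
marked finished the first time it emits <eos>, and decoding halts once
every batch element is finished. A toy sketch with made-up token ids
(eos_id = 2 is an assumption):

    import torch

    eos_id = 2  # hypothetical <eos> index
    steps = [torch.tensor([5, 2, 7]),   # step 0 predictions for a batch of 3
             torch.tensor([2, 2, 9]),   # step 1
             torch.tensor([4, 2, 2])]   # step 2

    eos_yet = torch.zeros(3, dtype=torch.bool)  # the diff uses .byte(); bool is the modern dtype
    for t, preds in enumerate(steps):
        eos_yet = eos_yet | (preds == eos_id)   # sticky: once True, stays True
        if eos_yet.all():                       # every sequence has emitted <eos>
            break
    # breaks at t == 2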
@@ -5,7 +5,6 @@ import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.autograd import Variable
 
 from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask
 
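Dropping the import is safe because, from PyTorch 0.4 on,
torch.autograd.Variable is only a thin deprecated wrapper. A quick check
(illustration only):

    import torch
    from torch.autograd import Variable  # still importable, but deprecated since 0.4

    v = Variable(torch.ones(2))
    assert isinstance(v, torch.Tensor)   # Variable(...) now just returns a Tensor
    assert not v.requires_grad           # grad tracking stays off unless requested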
@@ -101,30 +100,28 @@ class PointerGenerator(nn.Module):
         effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx)
         if self.generative_vocab_size < effective_vocab_size:
             size[-1] = effective_vocab_size - self.generative_vocab_size
-            buff = Variable(scaled_p_vocab.data.new(*size).fill_(EPSILON))
+            buff = scaled_p_vocab.new_full(size, EPSILON)
             scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1)
 
-        p_context_ptr = Variable(scaled_p_vocab.data.new(*scaled_p_vocab.size()).fill_(EPSILON))
+        p_context_ptr = scaled_p_vocab.new_full(scaled_p_vocab.size(), EPSILON)
         p_context_ptr.scatter_add_(p_context_ptr.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention), context_attention)
         scaled_p_context_ptr = (1 - vocab_pointer_switches).expand_as(p_context_ptr) * p_context_ptr
 
-        probs = scaled_p_vocab + scaled_p_context_ptr #+ scaled_p_question_ptr
+        probs = scaled_p_vocab + scaled_p_context_ptr
         return probs
 
 
     def greedy(self, context, context_indices, oov_to_limited_idx, rnn_state=None):
         B, TC, C = context.size()
         T = self.args.max_output_length
-        outs = Variable(context.data.new(B, T).long().fill_(
-            self.field.decoder_stoi['<pad>']), volatile=True)
+        outs = context.new_full((B, T), self.field.decoder_stoi['<pad>'], dtype=torch.long)
         eos_yet = context.data.new(B).byte().zero_()
 
         rnn_output, context_alignment = None, None
         for t in range(T):
             if t == 0:
-                embedding = self.decoder_embeddings(Variable(
-                    context[-1].data.new(B).long().fill_(
-                        self.field.vocab.stoi['<init>']), volatile=True).unsqueeze(1), [1]*B)
+                embedding = self.decoder_embeddings(
+                    self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long), [1]*B)
             else:
                 embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B)
             decoder_outputs = self.dual_ptr_rnn_decoder(embedding, #hiddens[-1][:, t].unsqueeze(1),
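For context, the scatter_add_ call in probs above is what folds pointer
attention into the (extended) vocabulary distribution: each context token's
attention weight is added at the slot of its vocabulary id, so repeated
tokens accumulate. A toy sketch with made-up shapes and values:

    import torch

    vocab_size = 6
    attention = torch.tensor([[0.5, 0.3, 0.2]])    # attention over 3 context tokens
    context_ids = torch.tensor([[4, 1, 4]])        # vocabulary id of each context token

    p_ptr = torch.full((1, vocab_size), 1e-10)     # EPSILON floor, as in the diff
    p_ptr.scatter_add_(1, context_ids, attention)  # id 4 appears twice: 0.5 + 0.2
    # p_ptr -> [[eps, 0.3, eps, eps, 0.7, eps]]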
@@ -138,8 +135,9 @@ class PointerGenerator(nn.Module):
                                context_indices,
                                oov_to_limited_idx)
             pred_probs, preds = probs.max(-1)
-            eos_yet = eos_yet | (preds.data == self.field.decoder_stoi['<eos>'])
-            outs[:, t] = Variable(preds.data.cpu().apply_(self.map_to_full), volatile=True)
+            preds = preds.squeeze(1)
+            eos_yet = eos_yet | (preds == self.field.decoder_stoi['<eos>'])
+            outs[:, t] = preds.cpu().apply_(self.map_to_full)
             if eos_yet.all():
                 break
         return outs
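The new outs[:, t] assignment relies on Tensor.apply_, which runs a Python
callable element-wise in place and is only available on CPU tensors, hence
the preceding .cpu(). A sketch with a made-up mapping (map_to_full here is
a hypothetical stand-in for the model's limited-to-full vocabulary
remapping):

    import torch

    def map_to_full(idx):        # hypothetical stand-in for self.map_to_full
        return idx + 100         # e.g. shift limited-vocab ids into the full vocab

    preds = torch.tensor([3, 7, 1])             # CPU long tensor
    mapped = preds.clone().apply_(map_to_full)  # in-place, element-wise, CPU-only
    # mapped -> tensor([103, 107, 101])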