Split MQAN model into encoder and decoder
This will make it easier to replace the encoder without touching the decoder.
parent c4a9c49d48
commit 61b8db12bf
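A condensed sketch of the new top-level structure introduced by this commit (bookkeeping attributes and the full method bodies are in the diff below):

# Sketch only: MQANEncoder owns the embeddings, coattention and BiLSTM/transformer stack,
# MQANDecoder owns the pointer-generator decoder, and the original class becomes a thin wrapper.
class MultitaskQuestionAnsweringNetwork(nn.Module):
    def __init__(self, field, args):
        super().__init__()
        self.encoder = MQANEncoder(field, args)
        self.decoder = MQANDecoder(field, args)

    def set_embeddings(self, embeddings):
        self.encoder.set_embeddings(embeddings)
        self.decoder.set_embeddings(embeddings)

    def forward(self, batch, iteration):
        # The encoder returns exactly what the decoder consumes, so a replacement
        # encoder only has to reproduce this return signature.
        self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)
        loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
                                         final_question, question_rnn_state)
        return loss, predictions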
@@ -34,26 +34,95 @@ from ..util import get_trainable_params, set_seed
 
 from .common import *
 
-class MultitaskQuestionAnsweringNetwork(nn.Module):
+class MQANEncoder(nn.Module):
 
     def __init__(self, field, args):
         super().__init__()
         self.field = field
         self.args = args
         self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
-        self.device = set_seed(args)
-
-        def dp(args):
-            return args.dropout_ratio if args.rnn_layers > 1 else 0.
 
         if self.args.glove_and_char:
 
             self.encoder_embeddings = Embedding(field, args.dimension,
                                                 trained_dimension=0,
                                                 dropout=args.dropout_ratio,
                                                 project=True,
                                                 requires_grad=args.retrain_encoder_embedding)
 
+        def dp(args):
+            return args.dropout_ratio if args.rnn_layers > 1 else 0.
+
+        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
+                                                    batch_first=True, bidirectional=True, num_layers=1, dropout=0)
+        self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
+        dim = 2 * args.dimension + args.dimension + args.dimension
+
+        self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
+                                                           batch_first=True, dropout=dp(args), bidirectional=True,
+                                                           num_layers=args.rnn_layers)
+        self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads,
+                                                                 args.transformer_hidden, args.transformer_layers,
+                                                                 args.dropout_ratio)
+        self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
+                                         batch_first=True, dropout=dp(args), bidirectional=True,
+                                         num_layers=args.rnn_layers)
+
+        self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
+                                                            batch_first=True, dropout=dp(args), bidirectional=True,
+                                                            num_layers=args.rnn_layers)
+        self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads,
+                                                                  args.transformer_hidden, args.transformer_layers,
+                                                                  args.dropout_ratio)
+        self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
+                                          batch_first=True, dropout=dp(args), bidirectional=True,
+                                          num_layers=args.rnn_layers)
+
+    def set_embeddings(self, embeddings):
+        self.encoder_embeddings.set_embeddings(embeddings)
+
+    def forward(self, batch):
+        context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
+        question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
+
+        context_embedded = self.encoder_embeddings(context)
+        question_embedded = self.encoder_embeddings(question)
+
+        context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
+        question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
+
+        context_padding = context.data == self.pad_idx
+        question_padding = question.data == self.pad_idx
+
+        coattended_context, coattended_question = self.coattention(context_encoded, question_encoded,
+                                                                   context_padding, question_padding)
+
+        context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
+        condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
+        self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
+        final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1],
+                                                                            context_lengths)
+        context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
+
+        question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
+        condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
+        self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
+        final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1],
+                                                                                question_lengths)
+        question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
+
+        return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state
+
+    def reshape_rnn_state(self, h):
+        return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
+                .transpose(1, 2).contiguous() \
+                .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
+
+
+class MQANDecoder(nn.Module):
+    def __init__(self, field, args):
+        super().__init__()
+        self.field = field
+        self.args = args
+
         if args.pretrained_decoder_lm:
             pretrained_save_dict = torch.load(os.path.join(args.embeddings, args.pretrained_decoder_lm), map_location=str(self.device))
 
@@ -86,27 +155,6 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                                                 include_pretrained=args.glove_decoder,
                                                 trained_dimension=args.trainable_decoder_embedding,
                                                 dropout=args.dropout_ratio, project=True)
 
-        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
-                                                    batch_first=True, bidirectional=True, num_layers=1, dropout=0)
-        self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
-        dim = 2*args.dimension + args.dimension + args.dimension
-
-        self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
-                                                           batch_first=True, dropout=dp(args), bidirectional=True,
-                                                           num_layers=args.rnn_layers)
-        self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
-        self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
-                                         batch_first=True, dropout=dp(args), bidirectional=True,
-                                         num_layers=args.rnn_layers)
-
-        self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
-                                                            batch_first=True, dropout=dp(args), bidirectional=True,
-                                                            num_layers=args.rnn_layers)
-        self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
-        self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
-                                          batch_first=True, dropout=dp(args), bidirectional=True,
-                                          num_layers=args.rnn_layers)
-
         self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
         self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension,
@@ -115,63 +163,39 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         self.generative_vocab_size = min(len(field.vocab), args.max_generative_vocab)
         self.out = nn.Linear(args.dimension, self.generative_vocab_size)
 
-        self.dropout = nn.Dropout(0.4)
 
     def set_embeddings(self, embeddings):
-        self.encoder_embeddings.set_embeddings(embeddings)
         if self.decoder_embeddings is not None:
             self.decoder_embeddings.set_embeddings(embeddings)
 
-    def forward(self, batch, iteration):
+    def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state):
         context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
         question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
         answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens
         decoder_vocab = batch.decoder_vocab
 
         self.map_to_full = decoder_vocab.decode
 
-        context_embedded = self.encoder_embeddings(context)
-        question_embedded = self.encoder_embeddings(question)
-
-        context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
-        question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
-
-        context_padding = context.data == self.pad_idx
-        question_padding = question.data == self.pad_idx
-
-        coattended_context, coattended_question = self.coattention(context_encoded, question_encoded, context_padding, question_padding)
-
-        context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
-        condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
-        self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
-        final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1], context_lengths)
-        context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
-
-        question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
-        condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
-        self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
-        final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1], question_lengths)
-        question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
-
         context_indices = context_limited if context_limited is not None else context
         question_indices = question_limited if question_limited is not None else question
         answer_indices = answer_limited if answer_limited is not None else answer
 
-        pad_idx = self.field.decoder_vocab.stoi[self.field.pad_token]
-        context_padding = context_indices.data == pad_idx
-        question_padding = question_indices.data == pad_idx
+        decoder_pad_idx = self.field.decoder_vocab.stoi[self.field.pad_token]
+        context_padding = context_indices.data == decoder_pad_idx
+        question_padding = question_indices.data == decoder_pad_idx
 
         self.dual_ptr_rnn_decoder.applyMasks(context_padding, question_padding)
 
         if self.training:
-            answer_padding = (answer_indices.data == pad_idx)[:, :-1]
+            answer_padding = (answer_indices.data == decoder_pad_idx)[:, :-1]
 
             if self.args.pretrained_decoder_lm:
                 # note that pretrained_decoder_embeddings is time first
                 answer_pretrained_numerical = [
-                    [self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in range(len(answer_tokens[0]))
+                    [self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in
+                    range(len(answer_tokens[0]))
                 ]
-                answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long, device=self.device)
+                answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long,
+                                                           device=self.device)
 
                 with torch.no_grad():
                     answer_embedded, _ = self.pretrained_decoder_embeddings.encode(answer_pretrained_numerical)
@@ -181,68 +205,68 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                     answer_embedded = self.pretrained_decoder_embedding_projection(answer_embedded)
             else:
                 answer_embedded = self.decoder_embeddings(answer)
-            self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(), self_attended_context, context_padding=context_padding, answer_padding=answer_padding, positional_encodings=True)
-            decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
-                                                        final_context, final_question, hidden=context_rnn_state)
+            self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(),
+                                                                self_attended_context, context_padding=context_padding,
+                                                                answer_padding=answer_padding,
+                                                                positional_encodings=True)
+            decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
+                                                        final_context, final_question, hidden=context_rnn_state)
             rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
 
             probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
                                context_attention, question_attention,
                                context_indices, question_indices,
                                decoder_vocab)
 
-            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
+            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=decoder_pad_idx)
             loss = F.nll_loss(probs.log(), targets)
             return loss, None
 
         else:
             return None, self.greedy(self_attended_context, final_context, final_question,
                                      context_indices, question_indices,
                                      decoder_vocab, rnn_state=context_rnn_state).data
 
-    def reshape_rnn_state(self, h):
-        return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
-                .transpose(1, 2).contiguous() \
-                .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
-
     def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
               context_attention, question_attention,
               context_indices, question_indices,
              decoder_vocab):
 
         size = list(outputs.size())
 
         size[-1] = self.generative_vocab_size
         scores = generator(outputs.view(-1, outputs.size(-1))).view(size)
-        p_vocab = F.softmax(scores, dim=scores.dim()-1)
+        p_vocab = F.softmax(scores, dim=scores.dim() - 1)
         scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
 
         effective_vocab_size = len(decoder_vocab)
         if self.generative_vocab_size < effective_vocab_size:
             size[-1] = effective_vocab_size - self.generative_vocab_size
             buff = scaled_p_vocab.new_full(size, EPSILON)
-            scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1)
+            scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)
 
         # p_context_ptr
-        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention),
-                                    (context_question_switches * (1 - vocab_pointer_switches)).expand_as(context_attention) * context_attention)
+        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, context_indices.unsqueeze(1).expand_as(context_attention),
+                                    (context_question_switches * (1 - vocab_pointer_switches)).expand_as(
+                                        context_attention) * context_attention)
 
         # p_question_ptr
-        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim()-1, question_indices.unsqueeze(1).expand_as(question_attention),
-                                    ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(question_attention) * question_attention)
+        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
+                                    question_indices.unsqueeze(1).expand_as(question_attention),
+                                    ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(
+                                        question_attention) * question_attention)
 
         return scaled_p_vocab
 
-    def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, rnn_state=None):
+    def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab,
+               rnn_state=None):
         B, TC, C = context.size()
         T = self.args.max_output_length
         outs = context.new_full((B, T), self.field.decoder_vocab.stoi[self.field.pad_token], dtype=torch.long)
         hiddens = [self_attended_context[0].new_zeros((B, T, C))
                    for l in range(len(self.self_attentive_decoder.layers) + 1)]
         hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
-        eos_yet = context.new_zeros((B, )).byte()
+        eos_yet = context.new_zeros((B,)).byte()
 
         pretrained_lm_hidden = None
         if self.args.pretrained_decoder_lm:
@@ -251,17 +275,21 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         for t in range(T):
             if t == 0:
                 if self.args.pretrained_decoder_lm:
-                    init_token = self_attended_context[-1].new_full((1, B), self.pretrained_decoder_vocab_stoi['<init>'], dtype=torch.long)
+                    init_token = self_attended_context[-1].new_full((1, B),
+                                                                    self.pretrained_decoder_vocab_stoi['<init>'],
+                                                                    dtype=torch.long)
 
                     # note that pretrained_decoder_embeddings is time first
-                    embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token, pretrained_lm_hidden)
+                    embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token,
+                                                                                                pretrained_lm_hidden)
                     embedding.transpose_(0, 1)
 
                     if self.pretrained_decoder_embedding_projection is not None:
                         embedding = self.pretrained_decoder_embedding_projection(embedding)
                 else:
-                    init_token = self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long)
-                    embedding = self.decoder_embeddings(init_token, [1]*B)
+                    init_token = self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'],
+                                                                    dtype=torch.long)
+                    embedding = self.decoder_embeddings(init_token, [1] * B)
             else:
                 if self.args.pretrained_decoder_lm:
                     current_token = [self.field.vocab.itos[x] for x in outs[:, t - 1]]
@@ -269,7 +297,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                                                    dtype=torch.long, device=self.device, requires_grad=False)
                     embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(current_token_id,
                                                                                                 pretrained_lm_hidden)
 
                     # note that pretrained_decoder_embeddings is time first
                     embedding.transpose_(0, 1)
 
@@ -277,23 +305,26 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                         embedding = self.pretrained_decoder_embedding_projection(embedding)
                 else:
                     current_token_id = outs[:, t - 1].unsqueeze(1)
-                    embedding = self.decoder_embeddings(current_token_id, [1]*B)
+                    embedding = self.decoder_embeddings(current_token_id, [1] * B)
 
-            hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
+            hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(
+                1)
             for l in range(len(self.self_attentive_decoder.layers)):
                 hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
                     self.self_attentive_decoder.layers[l].attention(
-                        self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], hiddens[l][:, :t + 1])
-                        , self_attended_context[l], self_attended_context[l]))
+                        self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1],
+                                                                       hiddens[l][:, :t + 1])
+                        , self_attended_context[l], self_attended_context[l]))
             decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1),
                                                         context, question,
-                                                        context_alignment=context_alignment, question_alignment=question_alignment,
-                                                        hidden=rnn_state, output=rnn_output)
+                                                        context_alignment=context_alignment,
+                                                        question_alignment=question_alignment,
+                                                        hidden=rnn_state, output=rnn_output)
             rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
             probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
                                context_attention, question_attention,
                                context_indices, question_indices,
                                decoder_vocab)
             pred_probs, preds = probs.max(-1)
             preds = preds.squeeze(1)
             eos_yet = eos_yet | (preds == self.field.decoder_vocab.stoi[self.field.eos_token]).byte()
@@ -303,6 +334,31 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         return outs
 
 
+class MultitaskQuestionAnsweringNetwork(nn.Module):
+
+    def __init__(self, field, args):
+        super().__init__()
+        self.field = field
+        self.args = args
+        self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
+        self.device = set_seed(args)
+
+        self.encoder = MQANEncoder(field, args)
+        self.decoder = MQANDecoder(field, args)
+
+    def set_embeddings(self, embeddings):
+        self.encoder.set_embeddings(embeddings)
+        self.decoder.set_embeddings(embeddings)
+
+    def forward(self, batch, iteration):
+        self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)
+
+        loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
+                                         final_question, question_rnn_state)
+
+        return loss, predictions
+
+
 class DualPtrRNNDecoder(nn.Module):