Split MQAN model in encoder and decoder
This will make it easier to replace the encoder without touching the decoder.
This commit is contained in:
parent
c4a9c49d48
commit
61b8db12bf
|
@ -34,26 +34,95 @@ from ..util import get_trainable_params, set_seed
|
|||
|
||||
from .common import *
|
||||
|
||||
class MultitaskQuestionAnsweringNetwork(nn.Module):
|
||||
|
||||
class MQANEncoder(nn.Module):
|
||||
def __init__(self, field, args):
|
||||
super().__init__()
|
||||
self.field = field
|
||||
self.args = args
|
||||
self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
|
||||
self.device = set_seed(args)
|
||||
|
||||
def dp(args):
|
||||
return args.dropout_ratio if args.rnn_layers > 1 else 0.
|
||||
|
||||
if self.args.glove_and_char:
|
||||
|
||||
self.encoder_embeddings = Embedding(field, args.dimension,
|
||||
trained_dimension=0,
|
||||
dropout=args.dropout_ratio,
|
||||
project=True,
|
||||
requires_grad=args.retrain_encoder_embedding)
|
||||
|
||||
|
||||
def dp(args):
|
||||
return args.dropout_ratio if args.rnn_layers > 1 else 0.
|
||||
|
||||
self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
|
||||
batch_first=True, bidirectional=True, num_layers=1, dropout=0)
|
||||
self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
|
||||
dim = 2 * args.dimension + args.dimension + args.dimension
|
||||
|
||||
self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads,
|
||||
args.transformer_hidden, args.transformer_layers,
|
||||
args.dropout_ratio)
|
||||
self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
|
||||
self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads,
|
||||
args.transformer_hidden, args.transformer_layers,
|
||||
args.dropout_ratio)
|
||||
self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
|
||||
def set_embeddings(self, embeddings):
|
||||
self.encoder_embeddings.set_embeddings(embeddings)
|
||||
|
||||
def forward(self, batch):
|
||||
context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
|
||||
question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
|
||||
|
||||
context_embedded = self.encoder_embeddings(context)
|
||||
question_embedded = self.encoder_embeddings(question)
|
||||
|
||||
context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
|
||||
question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
|
||||
|
||||
context_padding = context.data == self.pad_idx
|
||||
question_padding = question.data == self.pad_idx
|
||||
|
||||
coattended_context, coattended_question = self.coattention(context_encoded, question_encoded,
|
||||
context_padding, question_padding)
|
||||
|
||||
context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
|
||||
condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
|
||||
self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
|
||||
final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1],
|
||||
context_lengths)
|
||||
context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
|
||||
|
||||
question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
|
||||
condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
|
||||
self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
|
||||
final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1],
|
||||
question_lengths)
|
||||
question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
|
||||
|
||||
return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state
|
||||
|
||||
def reshape_rnn_state(self, h):
|
||||
return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
|
||||
.transpose(1, 2).contiguous() \
|
||||
.view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
|
||||
|
||||
|
||||
class MQANDecoder(nn.Module):
|
||||
def __init__(self, field, args):
|
||||
super().__init__()
|
||||
self.field = field
|
||||
self.args = args
|
||||
|
||||
if args.pretrained_decoder_lm:
|
||||
pretrained_save_dict = torch.load(os.path.join(args.embeddings, args.pretrained_decoder_lm), map_location=str(self.device))
|
||||
|
||||
|
@ -86,27 +155,6 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
include_pretrained=args.glove_decoder,
|
||||
trained_dimension=args.trainable_decoder_embedding,
|
||||
dropout=args.dropout_ratio, project=True)
|
||||
|
||||
self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
|
||||
batch_first=True, bidirectional=True, num_layers=1, dropout=0)
|
||||
self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
|
||||
dim = 2*args.dimension + args.dimension + args.dimension
|
||||
|
||||
self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
|
||||
self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
|
||||
self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
|
||||
self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
|
||||
batch_first=True, dropout=dp(args), bidirectional=True,
|
||||
num_layers=args.rnn_layers)
|
||||
|
||||
self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
|
||||
self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension,
|
||||
|
@ -115,63 +163,39 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
self.generative_vocab_size = min(len(field.vocab), args.max_generative_vocab)
|
||||
self.out = nn.Linear(args.dimension, self.generative_vocab_size)
|
||||
|
||||
self.dropout = nn.Dropout(0.4)
|
||||
|
||||
def set_embeddings(self, embeddings):
|
||||
self.encoder_embeddings.set_embeddings(embeddings)
|
||||
if self.decoder_embeddings is not None:
|
||||
self.decoder_embeddings.set_embeddings(embeddings)
|
||||
|
||||
def forward(self, batch, iteration):
|
||||
context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
|
||||
def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state):
|
||||
context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
|
||||
question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
|
||||
answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens
|
||||
decoder_vocab = batch.decoder_vocab
|
||||
answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens
|
||||
decoder_vocab = batch.decoder_vocab
|
||||
|
||||
self.map_to_full = decoder_vocab.decode
|
||||
|
||||
context_embedded = self.encoder_embeddings(context)
|
||||
question_embedded = self.encoder_embeddings(question)
|
||||
|
||||
context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
|
||||
question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
|
||||
|
||||
context_padding = context.data == self.pad_idx
|
||||
question_padding = question.data == self.pad_idx
|
||||
|
||||
coattended_context, coattended_question = self.coattention(context_encoded, question_encoded, context_padding, question_padding)
|
||||
|
||||
context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
|
||||
condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
|
||||
self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
|
||||
final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1], context_lengths)
|
||||
context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
|
||||
|
||||
question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
|
||||
condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
|
||||
self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
|
||||
final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1], question_lengths)
|
||||
question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
|
||||
|
||||
context_indices = context_limited if context_limited is not None else context
|
||||
question_indices = question_limited if question_limited is not None else question
|
||||
answer_indices = answer_limited if answer_limited is not None else answer
|
||||
|
||||
pad_idx = self.field.decoder_vocab.stoi[self.field.pad_token]
|
||||
context_padding = context_indices.data == pad_idx
|
||||
question_padding = question_indices.data == pad_idx
|
||||
decoder_pad_idx = self.field.decoder_vocab.stoi[self.field.pad_token]
|
||||
context_padding = context_indices.data == decoder_pad_idx
|
||||
question_padding = question_indices.data == decoder_pad_idx
|
||||
|
||||
self.dual_ptr_rnn_decoder.applyMasks(context_padding, question_padding)
|
||||
|
||||
if self.training:
|
||||
answer_padding = (answer_indices.data == pad_idx)[:, :-1]
|
||||
answer_padding = (answer_indices.data == decoder_pad_idx)[:, :-1]
|
||||
|
||||
if self.args.pretrained_decoder_lm:
|
||||
# note that pretrained_decoder_embeddings is time first
|
||||
answer_pretrained_numerical = [
|
||||
[self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in range(len(answer_tokens[0]))
|
||||
[self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in
|
||||
range(len(answer_tokens[0]))
|
||||
]
|
||||
answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long, device=self.device)
|
||||
answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long,
|
||||
device=self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
answer_embedded, _ = self.pretrained_decoder_embeddings.encode(answer_pretrained_numerical)
|
||||
|
@ -181,68 +205,68 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
answer_embedded = self.pretrained_decoder_embedding_projection(answer_embedded)
|
||||
else:
|
||||
answer_embedded = self.decoder_embeddings(answer)
|
||||
self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(), self_attended_context, context_padding=context_padding, answer_padding=answer_padding, positional_encodings=True)
|
||||
decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
|
||||
final_context, final_question, hidden=context_rnn_state)
|
||||
self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(),
|
||||
self_attended_context, context_padding=context_padding,
|
||||
answer_padding=answer_padding,
|
||||
positional_encodings=True)
|
||||
decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
|
||||
final_context, final_question, hidden=context_rnn_state)
|
||||
rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
|
||||
|
||||
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab)
|
||||
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab)
|
||||
|
||||
|
||||
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
|
||||
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=decoder_pad_idx)
|
||||
loss = F.nll_loss(probs.log(), targets)
|
||||
return loss, None
|
||||
|
||||
else:
|
||||
return None, self.greedy(self_attended_context, final_context, final_question,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab, rnn_state=context_rnn_state).data
|
||||
|
||||
def reshape_rnn_state(self, h):
|
||||
return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
|
||||
.transpose(1, 2).contiguous() \
|
||||
.view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
|
||||
return None, self.greedy(self_attended_context, final_context, final_question,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab, rnn_state=context_rnn_state).data
|
||||
|
||||
def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab):
|
||||
def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab):
|
||||
|
||||
size = list(outputs.size())
|
||||
|
||||
size[-1] = self.generative_vocab_size
|
||||
scores = generator(outputs.view(-1, outputs.size(-1))).view(size)
|
||||
p_vocab = F.softmax(scores, dim=scores.dim()-1)
|
||||
p_vocab = F.softmax(scores, dim=scores.dim() - 1)
|
||||
scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
|
||||
|
||||
effective_vocab_size = len(decoder_vocab)
|
||||
if self.generative_vocab_size < effective_vocab_size:
|
||||
size[-1] = effective_vocab_size - self.generative_vocab_size
|
||||
buff = scaled_p_vocab.new_full(size, EPSILON)
|
||||
scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1)
|
||||
scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)
|
||||
|
||||
# p_context_ptr
|
||||
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention),
|
||||
(context_question_switches * (1 - vocab_pointer_switches)).expand_as(context_attention) * context_attention)
|
||||
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, context_indices.unsqueeze(1).expand_as(context_attention),
|
||||
(context_question_switches * (1 - vocab_pointer_switches)).expand_as(
|
||||
context_attention) * context_attention)
|
||||
|
||||
# p_question_ptr
|
||||
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim()-1, question_indices.unsqueeze(1).expand_as(question_attention),
|
||||
((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(question_attention) * question_attention)
|
||||
scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
|
||||
question_indices.unsqueeze(1).expand_as(question_attention),
|
||||
((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(
|
||||
question_attention) * question_attention)
|
||||
|
||||
return scaled_p_vocab
|
||||
|
||||
|
||||
def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, rnn_state=None):
|
||||
def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab,
|
||||
rnn_state=None):
|
||||
B, TC, C = context.size()
|
||||
T = self.args.max_output_length
|
||||
outs = context.new_full((B, T), self.field.decoder_vocab.stoi[self.field.pad_token], dtype=torch.long)
|
||||
hiddens = [self_attended_context[0].new_zeros((B, T, C))
|
||||
for l in range(len(self.self_attentive_decoder.layers) + 1)]
|
||||
hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
|
||||
eos_yet = context.new_zeros((B, )).byte()
|
||||
eos_yet = context.new_zeros((B,)).byte()
|
||||
|
||||
pretrained_lm_hidden = None
|
||||
if self.args.pretrained_decoder_lm:
|
||||
|
@ -251,17 +275,21 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
for t in range(T):
|
||||
if t == 0:
|
||||
if self.args.pretrained_decoder_lm:
|
||||
init_token = self_attended_context[-1].new_full((1, B), self.pretrained_decoder_vocab_stoi['<init>'], dtype=torch.long)
|
||||
|
||||
init_token = self_attended_context[-1].new_full((1, B),
|
||||
self.pretrained_decoder_vocab_stoi['<init>'],
|
||||
dtype=torch.long)
|
||||
|
||||
# note that pretrained_decoder_embeddings is time first
|
||||
embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token, pretrained_lm_hidden)
|
||||
embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token,
|
||||
pretrained_lm_hidden)
|
||||
embedding.transpose_(0, 1)
|
||||
|
||||
if self.pretrained_decoder_embedding_projection is not None:
|
||||
embedding = self.pretrained_decoder_embedding_projection(embedding)
|
||||
else:
|
||||
init_token = self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long)
|
||||
embedding = self.decoder_embeddings(init_token, [1]*B)
|
||||
init_token = self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'],
|
||||
dtype=torch.long)
|
||||
embedding = self.decoder_embeddings(init_token, [1] * B)
|
||||
else:
|
||||
if self.args.pretrained_decoder_lm:
|
||||
current_token = [self.field.vocab.itos[x] for x in outs[:, t - 1]]
|
||||
|
@ -269,7 +297,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
dtype=torch.long, device=self.device, requires_grad=False)
|
||||
embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(current_token_id,
|
||||
pretrained_lm_hidden)
|
||||
|
||||
|
||||
# note that pretrained_decoder_embeddings is time first
|
||||
embedding.transpose_(0, 1)
|
||||
|
||||
|
@ -277,23 +305,26 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
embedding = self.pretrained_decoder_embedding_projection(embedding)
|
||||
else:
|
||||
current_token_id = outs[:, t - 1].unsqueeze(1)
|
||||
embedding = self.decoder_embeddings(current_token_id, [1]*B)
|
||||
embedding = self.decoder_embeddings(current_token_id, [1] * B)
|
||||
|
||||
hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
|
||||
hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(
|
||||
1)
|
||||
for l in range(len(self.self_attentive_decoder.layers)):
|
||||
hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
|
||||
self.self_attentive_decoder.layers[l].attention(
|
||||
self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], hiddens[l][:, :t + 1])
|
||||
, self_attended_context[l], self_attended_context[l]))
|
||||
self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1],
|
||||
hiddens[l][:, :t + 1])
|
||||
, self_attended_context[l], self_attended_context[l]))
|
||||
decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1),
|
||||
context, question,
|
||||
context_alignment=context_alignment, question_alignment=question_alignment,
|
||||
hidden=rnn_state, output=rnn_output)
|
||||
context, question,
|
||||
context_alignment=context_alignment,
|
||||
question_alignment=question_alignment,
|
||||
hidden=rnn_state, output=rnn_output)
|
||||
rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
|
||||
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab)
|
||||
probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
|
||||
context_attention, question_attention,
|
||||
context_indices, question_indices,
|
||||
decoder_vocab)
|
||||
pred_probs, preds = probs.max(-1)
|
||||
preds = preds.squeeze(1)
|
||||
eos_yet = eos_yet | (preds == self.field.decoder_vocab.stoi[self.field.eos_token]).byte()
|
||||
|
@ -303,6 +334,31 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
return outs
|
||||
|
||||
|
||||
class MultitaskQuestionAnsweringNetwork(nn.Module):
|
||||
|
||||
def __init__(self, field, args):
|
||||
super().__init__()
|
||||
self.field = field
|
||||
self.args = args
|
||||
self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
|
||||
self.device = set_seed(args)
|
||||
|
||||
self.encoder = MQANEncoder(field, args)
|
||||
self.decoder = MQANDecoder(field, args)
|
||||
|
||||
|
||||
def set_embeddings(self, embeddings):
|
||||
self.encoder.set_embeddings(embeddings)
|
||||
self.decoder.set_embeddings(embeddings)
|
||||
|
||||
def forward(self, batch, iteration):
|
||||
self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)
|
||||
|
||||
loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
|
||||
final_question, question_rnn_state)
|
||||
|
||||
return loss, predictions
|
||||
|
||||
|
||||
class DualPtrRNNDecoder(nn.Module):
|
||||
|
||||
|
|
Loading…
Reference in New Issue