diff --git a/decanlp/models/multitask_question_answering_network.py b/decanlp/models/multitask_question_answering_network.py
index ebabbb56..f8c7ea50 100644
--- a/decanlp/models/multitask_question_answering_network.py
+++ b/decanlp/models/multitask_question_answering_network.py
@@ -34,26 +34,95 @@ from ..util import get_trainable_params, set_seed
 from .common import *
 
 
-class MultitaskQuestionAnsweringNetwork(nn.Module):
-
+class MQANEncoder(nn.Module):
     def __init__(self, field, args):
         super().__init__()
         self.field = field
         self.args = args
         self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
-        self.device = set_seed(args)
-
-        def dp(args):
-            return args.dropout_ratio if args.rnn_layers > 1 else 0.
 
         if self.args.glove_and_char:
             self.encoder_embeddings = Embedding(field, args.dimension, trained_dimension=0, dropout=args.dropout_ratio, project=True, requires_grad=args.retrain_encoder_embedding)
-
+
+        def dp(args):
+            return args.dropout_ratio if args.rnn_layers > 1 else 0.
+
+        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
+                                                    batch_first=True, bidirectional=True, num_layers=1, dropout=0)
+        self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
+        dim = 2 * args.dimension + args.dimension + args.dimension
+
+        self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
+                                                           batch_first=True, dropout=dp(args), bidirectional=True,
+                                                           num_layers=args.rnn_layers)
+        self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads,
+                                                                 args.transformer_hidden, args.transformer_layers,
+                                                                 args.dropout_ratio)
+        self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
+                                         batch_first=True, dropout=dp(args), bidirectional=True,
+                                         num_layers=args.rnn_layers)
+
+        self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
+                                                            batch_first=True, dropout=dp(args), bidirectional=True,
+                                                            num_layers=args.rnn_layers)
+        self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads,
+                                                                  args.transformer_hidden, args.transformer_layers,
+                                                                  args.dropout_ratio)
+        self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
+                                          batch_first=True, dropout=dp(args), bidirectional=True,
+                                          num_layers=args.rnn_layers)
+
+    def set_embeddings(self, embeddings):
+        self.encoder_embeddings.set_embeddings(embeddings)
+
+    def forward(self, batch):
+        context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
+        question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
+
+        context_embedded = self.encoder_embeddings(context)
+        question_embedded = self.encoder_embeddings(question)
+
+        context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
+        question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
+
+        context_padding = context.data == self.pad_idx
+        question_padding = question.data == self.pad_idx
+
+        coattended_context, coattended_question = self.coattention(context_encoded, question_encoded,
+                                                                   context_padding, question_padding)
+
+        context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
+        condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
+        self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
+        final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1],
+                                                                            context_lengths)
+        context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
+
+        question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
+        condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
+        self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
+        final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1],
+                                                                                question_lengths)
+        question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
+
+        return self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state
+
+    def reshape_rnn_state(self, h):
+        return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
+                .transpose(1, 2).contiguous() \
+                .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
+
+
+class MQANDecoder(nn.Module):
+    def __init__(self, field, args):
+        super().__init__()
+        self.field = field
+        self.args = args
 
         if args.pretrained_decoder_lm:
             pretrained_save_dict = torch.load(os.path.join(args.embeddings, args.pretrained_decoder_lm), map_location=str(self.device))
@@ -86,27 +155,6 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                                                  include_pretrained=args.glove_decoder,
                                                  trained_dimension=args.trainable_decoder_embedding,
                                                  dropout=args.dropout_ratio, project=True)
-
-        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
-            batch_first=True, bidirectional=True, num_layers=1, dropout=0)
-        self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
-        dim = 2*args.dimension + args.dimension + args.dimension
-
-        self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
-            batch_first=True, dropout=dp(args), bidirectional=True,
-            num_layers=args.rnn_layers)
-        self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
-        self.bilstm_context = PackedLSTM(args.dimension, args.dimension,
-            batch_first=True, dropout=dp(args), bidirectional=True,
-            num_layers=args.rnn_layers)
-
-        self.question_bilstm_after_coattention = PackedLSTM(dim, args.dimension,
-            batch_first=True, dropout=dp(args), bidirectional=True,
-            num_layers=args.rnn_layers)
-        self.self_attentive_encoder_question = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
-        self.bilstm_question = PackedLSTM(args.dimension, args.dimension,
-            batch_first=True, dropout=dp(args), bidirectional=True,
-            num_layers=args.rnn_layers)
 
         self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio)
         self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension,
@@ -115,63 +163,39 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         self.generative_vocab_size = min(len(field.vocab), args.max_generative_vocab)
         self.out = nn.Linear(args.dimension, self.generative_vocab_size)
-        self.dropout = nn.Dropout(0.4)
 
     def set_embeddings(self, embeddings):
-        self.encoder_embeddings.set_embeddings(embeddings)
         if self.decoder_embeddings is not None:
             self.decoder_embeddings.set_embeddings(embeddings)
 
-    def forward(self, batch, iteration):
-        context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
+    def forward(self, batch, self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state):
+        context, context_lengths, context_limited, context_tokens = batch.context.value, batch.context.length, batch.context.limited, batch.context.tokens
         question, question_lengths, question_limited, question_tokens = batch.question.value, batch.question.length, batch.question.limited, batch.question.tokens
-        answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens
-        decoder_vocab = batch.decoder_vocab
+        answer, answer_lengths, answer_limited, answer_tokens = batch.answer.value, batch.answer.length, batch.answer.limited, batch.answer.tokens
+        decoder_vocab = batch.decoder_vocab
 
         self.map_to_full = decoder_vocab.decode
 
-        context_embedded = self.encoder_embeddings(context)
-        question_embedded = self.encoder_embeddings(question)
-
-        context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
-        question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
-
-        context_padding = context.data == self.pad_idx
-        question_padding = question.data == self.pad_idx
-
-        coattended_context, coattended_question = self.coattention(context_encoded, question_encoded, context_padding, question_padding)
-
-        context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1)
-        condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths)
-        self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding)
-        final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1], context_lengths)
-        context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)]
-
-        question_summary = torch.cat([coattended_question, question_encoded, question_embedded], -1)
-        condensed_question, _ = self.question_bilstm_after_coattention(question_summary, question_lengths)
-        self_attended_question = self.self_attentive_encoder_question(condensed_question, padding=question_padding)
-        final_question, (question_rnn_h, question_rnn_c) = self.bilstm_question(self_attended_question[-1], question_lengths)
-        question_rnn_state = [self.reshape_rnn_state(x) for x in (question_rnn_h, question_rnn_c)]
-
         context_indices = context_limited if context_limited is not None else context
         question_indices = question_limited if question_limited is not None else question
         answer_indices = answer_limited if answer_limited is not None else answer
 
-        pad_idx = self.field.decoder_vocab.stoi[self.field.pad_token]
-        context_padding = context_indices.data == pad_idx
-        question_padding = question_indices.data == pad_idx
+        decoder_pad_idx = self.field.decoder_vocab.stoi[self.field.pad_token]
+        context_padding = context_indices.data == decoder_pad_idx
+        question_padding = question_indices.data == decoder_pad_idx
 
         self.dual_ptr_rnn_decoder.applyMasks(context_padding, question_padding)
 
         if self.training:
-            answer_padding = (answer_indices.data == pad_idx)[:, :-1]
+            answer_padding = (answer_indices.data == decoder_pad_idx)[:, :-1]
 
             if self.args.pretrained_decoder_lm:
                 # note that pretrained_decoder_embeddings is time first
                 answer_pretrained_numerical = [
-                    [self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in range(len(answer_tokens[0]))
+                    [self.pretrained_decoder_vocab_stoi[sentence[time]] for sentence in answer_tokens] for time in
+                    range(len(answer_tokens[0]))
                 ]
-                answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long, device=self.device)
+                answer_pretrained_numerical = torch.tensor(answer_pretrained_numerical, dtype=torch.long,
+                                                           device=self.device)
 
                 with torch.no_grad():
                     answer_embedded, _ = self.pretrained_decoder_embeddings.encode(answer_pretrained_numerical)
@@ -181,68 +205,68 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                     answer_embedded = self.pretrained_decoder_embedding_projection(answer_embedded)
             else:
                 answer_embedded = self.decoder_embeddings(answer)
-            self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(), self_attended_context, context_padding=context_padding, answer_padding=answer_padding, positional_encodings=True)
-            decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
-                final_context, final_question, hidden=context_rnn_state)
+            self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(),
+                                                                self_attended_context, context_padding=context_padding,
+                                                                answer_padding=answer_padding,
+                                                                positional_encodings=True)
+            decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded,
+                                                        final_context, final_question, hidden=context_rnn_state)
             rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
-            probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
-                context_attention, question_attention,
-                context_indices, question_indices,
-                decoder_vocab)
+            probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
+                               context_attention, question_attention,
+                               context_indices, question_indices,
+                               decoder_vocab)
 
-
-            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
+            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=decoder_pad_idx)
             loss = F.nll_loss(probs.log(), targets)
             return loss, None
         else:
-            return None, self.greedy(self_attended_context, final_context, final_question,
-                context_indices, question_indices,
-                decoder_vocab, rnn_state=context_rnn_state).data
-
-    def reshape_rnn_state(self, h):
-        return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \
-                .transpose(1, 2).contiguous() \
-                .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous()
+            return None, self.greedy(self_attended_context, final_context, final_question,
+                                     context_indices, question_indices,
+                                     decoder_vocab, rnn_state=context_rnn_state).data
 
-    def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
-        context_attention, question_attention,
-        context_indices, question_indices,
-        decoder_vocab):
+    def probs(self, generator, outputs, vocab_pointer_switches, context_question_switches,
+              context_attention, question_attention,
+              context_indices, question_indices,
+              decoder_vocab):
         size = list(outputs.size())
         size[-1] = self.generative_vocab_size
         scores = generator(outputs.view(-1, outputs.size(-1))).view(size)
-        p_vocab = F.softmax(scores, dim=scores.dim()-1)
+        p_vocab = F.softmax(scores, dim=scores.dim() - 1)
         scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab
 
         effective_vocab_size = len(decoder_vocab)
         if self.generative_vocab_size < effective_vocab_size:
             size[-1] = effective_vocab_size - self.generative_vocab_size
             buff = scaled_p_vocab.new_full(size, EPSILON)
-            scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1)
+            scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim() - 1)
 
         # p_context_ptr
-        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention),
-            (context_question_switches * (1 - vocab_pointer_switches)).expand_as(context_attention) * context_attention)
+        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1, context_indices.unsqueeze(1).expand_as(context_attention),
+                                    (context_question_switches * (1 - vocab_pointer_switches)).expand_as(
+                                        context_attention) * context_attention)
 
         # p_question_ptr
-        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim()-1, question_indices.unsqueeze(1).expand_as(question_attention),
-            ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(question_attention) * question_attention)
+        scaled_p_vocab.scatter_add_(scaled_p_vocab.dim() - 1,
+                                    question_indices.unsqueeze(1).expand_as(question_attention),
+                                    ((1 - context_question_switches) * (1 - vocab_pointer_switches)).expand_as(
+                                        question_attention) * question_attention)
 
         return scaled_p_vocab
 
-
-    def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab, rnn_state=None):
+    def greedy(self, self_attended_context, context, question, context_indices, question_indices, decoder_vocab,
+               rnn_state=None):
         B, TC, C = context.size()
         T = self.args.max_output_length
         outs = context.new_full((B, T), self.field.decoder_vocab.stoi[self.field.pad_token], dtype=torch.long)
         hiddens = [self_attended_context[0].new_zeros((B, T, C)) for l in range(len(self.self_attentive_decoder.layers) + 1)]
         hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0])
-        eos_yet = context.new_zeros((B, )).byte()
+        eos_yet = context.new_zeros((B,)).byte()
 
         pretrained_lm_hidden = None
         if self.args.pretrained_decoder_lm:
@@ -251,17 +275,21 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         for t in range(T):
             if t == 0:
                 if self.args.pretrained_decoder_lm:
-                    init_token = self_attended_context[-1].new_full((1, B), self.pretrained_decoder_vocab_stoi['<init>'], dtype=torch.long)
-
+                    init_token = self_attended_context[-1].new_full((1, B),
+                                                                    self.pretrained_decoder_vocab_stoi['<init>'],
+                                                                    dtype=torch.long)
+
                     # note that pretrained_decoder_embeddings is time first
-                    embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token, pretrained_lm_hidden)
+                    embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(init_token,
+                                                                                                pretrained_lm_hidden)
                     embedding.transpose_(0, 1)
 
                     if self.pretrained_decoder_embedding_projection is not None:
                         embedding = self.pretrained_decoder_embedding_projection(embedding)
                 else:
-                    init_token = self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'], dtype=torch.long)
-                    embedding = self.decoder_embeddings(init_token, [1]*B)
+                    init_token = self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi['<init>'],
+                                                                    dtype=torch.long)
+                    embedding = self.decoder_embeddings(init_token, [1] * B)
             else:
                 if self.args.pretrained_decoder_lm:
                     current_token = [self.field.vocab.itos[x] for x in outs[:, t - 1]]
@@ -269,7 +297,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                                                     dtype=torch.long, device=self.device, requires_grad=False)
                     embedding, pretrained_lm_hidden = self.pretrained_decoder_embeddings.encode(current_token_id,
                                                                                                 pretrained_lm_hidden)
-
+
                     # note that pretrained_decoder_embeddings is time first
                     embedding.transpose_(0, 1)
 
@@ -277,23 +305,26 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                         embedding = self.pretrained_decoder_embedding_projection(embedding)
                 else:
                     current_token_id = outs[:, t - 1].unsqueeze(1)
-                    embedding = self.decoder_embeddings(current_token_id, [1]*B)
+                    embedding = self.decoder_embeddings(current_token_id, [1] * B)
 
-            hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1)
+            hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(
+                1)
             for l in range(len(self.self_attentive_decoder.layers)):
                 hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward(
                     self.self_attentive_decoder.layers[l].attention(
-                        self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], hiddens[l][:, :t + 1])
-                        , self_attended_context[l], self_attended_context[l]))
+                        self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1],
+                                                                       hiddens[l][:, :t + 1])
+                        , self_attended_context[l], self_attended_context[l]))
             decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1),
-                context, question,
-                context_alignment=context_alignment, question_alignment=question_alignment,
-                hidden=rnn_state, output=rnn_output)
+                                                        context, question,
+                                                        context_alignment=context_alignment,
+                                                        question_alignment=question_alignment,
+                                                        hidden=rnn_state, output=rnn_output)
             rnn_output, context_attention, question_attention, context_alignment, question_alignment, vocab_pointer_switch, context_question_switch, rnn_state = decoder_outputs
-            probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
-                context_attention, question_attention,
-                context_indices, question_indices,
-                decoder_vocab)
+            probs = self.probs(self.out, rnn_output, vocab_pointer_switch, context_question_switch,
+                               context_attention, question_attention,
+                               context_indices, question_indices,
+                               decoder_vocab)
             pred_probs, preds = probs.max(-1)
             preds = preds.squeeze(1)
             eos_yet = eos_yet | (preds == self.field.decoder_vocab.stoi[self.field.eos_token]).byte()
@@ -303,6 +334,31 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         return outs
 
 
+class MultitaskQuestionAnsweringNetwork(nn.Module):
+
+    def __init__(self, field, args):
+        super().__init__()
+        self.field = field
+        self.args = args
+        self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
+        self.device = set_seed(args)
+
+        self.encoder = MQANEncoder(field, args)
+        self.decoder = MQANDecoder(field, args)
+
+
+    def set_embeddings(self, embeddings):
+        self.encoder.set_embeddings(embeddings)
+        self.decoder.set_embeddings(embeddings)
+
+    def forward(self, batch, iteration):
+        self_attended_context, final_context, context_rnn_state, final_question, question_rnn_state = self.encoder(batch)
+
+        loss, predictions = self.decoder(batch, self_attended_context, final_context, context_rnn_state,
+                                         final_question, question_rnn_state)
+
+        return loss, predictions
+
 
 class DualPtrRNNDecoder(nn.Module):
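
For reference, a minimal sketch of how the two new modules compose; it mirrors what the refactored MultitaskQuestionAnsweringNetwork.forward does internally, and assumes `field`, `args`, and `batch` come from the existing decanlp training setup:

    from decanlp.models.multitask_question_answering_network import MQANEncoder, MQANDecoder

    encoder = MQANEncoder(field, args)   # embeddings, coattention, and the BiLSTM/transformer encoder stacks
    decoder = MQANDecoder(field, args)   # dual-pointer decoder plus the generative output projection

    # the encoder runs once per batch and returns everything the decoder needs
    self_attended_context, final_context, context_rnn_state, \
        final_question, question_rnn_state = encoder(batch)

    # training mode yields a loss; eval mode yields greedily decoded token ids
    loss, predictions = decoder(batch, self_attended_context, final_context,
                                context_rnn_state, final_question, question_rnn_state)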