From 6b4e29ae043c4fd5aad47af8b3e2b425652a52f9 Mon Sep 17 00:00:00 2001
From: Sina
Date: Tue, 3 Mar 2020 00:41:18 -0800
Subject: [PATCH] bug fixes

---
 genienlp/models/mqan_decoder.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/genienlp/models/mqan_decoder.py b/genienlp/models/mqan_decoder.py
index 817c6c97..fad1b754 100644
--- a/genienlp/models/mqan_decoder.py
+++ b/genienlp/models/mqan_decoder.py
@@ -182,7 +182,7 @@ class MQANDecoder(nn.Module):
         decoder_wrapper = MQANDecoderWrapper(self_attended_context, context, context_padding, question, question_padding,
                                              context_indices, question_indices, decoder_vocab, rnn_state, batch_size, max_decoder_time,
                                              self, num_beams=self.args.num_beams)
-        
+
         if self.args.num_beams > 1:
             outputs = self._decode_beam_search(
                 input_ids=input_ids,
@@ -204,25 +204,27 @@ class MQANDecoder(nn.Module):
             outputs = self._decode_greedy(
                 input_ids=input_ids,
                 max_length=self.args.max_output_length,
+                pad_token_id=decoder_vocab.pad_idx,
                 eos_token_id=decoder_vocab.eos_idx,
                 batch_size=batch_size,
                 decoder_wrapper=decoder_wrapper,
             )
 
-        # print('outputs = ', outputs)
+        # print('outputs = ', outputs.shape)
         return outputs
 
     def _decode_greedy(
             self,
             input_ids,
             max_length,
+            pad_token_id,
             eos_token_id,
             batch_size,
             decoder_wrapper
     ):
-        outs = input_ids.new_full((batch_size, max_length), self.pad_idx, dtype=torch.long)
+
+        outs = input_ids.new_full((batch_size, max_length), pad_token_id, dtype=torch.long)
         eos_yet = input_ids.new_zeros((batch_size,)).byte()
-
         for t in range(max_length):
             probs = decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))
             pred_probs, preds = probs.max(-1)
@@ -230,8 +232,7 @@ class MQANDecoder(nn.Module):
             outs[:, t] = preds.cpu().apply_(self.map_to_full)
             if eos_yet.all():
                 break
-            preds = preds.unsqueeze(1)
-            input_ids = torch.cat((input_ids, preds), dim=1)
+            input_ids = torch.cat((input_ids, outs[:, t].unsqueeze(1)), dim=1)
         return outs
 
     def _decode_beam_search(
@@ -272,10 +273,8 @@ class MQANDecoder(nn.Module):
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            # print('input_ids = ', input_ids[:, -1].unsqueeze(-1))
-
             # next_token_probs outputs a normalized probability distribution instead of logits
-            scores = torch.log(torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1)))) # (batch_size * num_beams, vocab_size)
+            scores = torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))) # (batch_size * num_beams, vocab_size)
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
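
For reviewers, a minimal runnable sketch of what the patched greedy path does, and why the beam-search change drops one torch.log. Everything below is illustrative: toy_next_token_probs, map_to_full, decode_greedy, and the toy vocabulary ids are assumptions standing in for MQANDecoderWrapper.next_token_probs, self.map_to_full, and decoder_vocab.pad_idx / eos_idx; they are not the genienlp API. The two behaviors mirrored from the patch are (1) the output buffer is pre-filled with the pad id passed in as pad_token_id, and the token appended to input_ids each step is the mapped full-vocabulary id in outs[:, t] rather than the raw decoder-vocabulary prediction, and (2) since next_token_probs already returns a normalized distribution, a single log yields log-probabilities, while log(log(p)) is NaN for any p in (0, 1).

    # Illustrative sketch only; stand-ins for the real genienlp decoder plumbing.
    import torch

    VOCAB_SIZE = 10          # toy full-vocabulary size (assumption for this sketch)
    PAD_IDX, EOS_IDX = 0, 2  # toy ids standing in for decoder_vocab.pad_idx / eos_idx


    def toy_next_token_probs(last_token):
        # Stand-in for decoder_wrapper.next_token_probs: returns a *normalized*
        # probability distribution over the vocabulary, not logits.
        logits = torch.randn(last_token.size(0), VOCAB_SIZE)
        return torch.softmax(logits, dim=-1)


    def map_to_full(idx):
        # Stand-in for self.map_to_full (decoder-vocab id -> full-vocab id);
        # identity here purely for illustration.
        return idx


    def decode_greedy(input_ids, max_length, pad_token_id, eos_token_id, batch_size):
        # Mirrors the patched _decode_greedy: the buffer starts out filled with
        # pad_token_id, and the token fed back into the decoder each step is the
        # mapped full-vocabulary id stored in outs[:, t].
        outs = input_ids.new_full((batch_size, max_length), pad_token_id, dtype=torch.long)
        eos_yet = input_ids.new_zeros((batch_size,)).bool()
        for t in range(max_length):
            probs = toy_next_token_probs(input_ids[:, -1].unsqueeze(-1))
            _, preds = probs.max(-1)
            eos_yet = eos_yet | (preds == eos_token_id)
            outs[:, t] = preds.cpu().apply_(map_to_full)
            if eos_yet.all():
                break
            input_ids = torch.cat((input_ids, outs[:, t].unsqueeze(1)), dim=1)
        return outs


    if __name__ == '__main__':
        start = torch.ones((3, 1), dtype=torch.long)  # toy start-of-sequence tokens
        print(decode_greedy(start, max_length=5, pad_token_id=PAD_IDX,
                            eos_token_id=EOS_IDX, batch_size=3))

        # Why beam search drops one torch.log: next_token_probs is already a
        # normalized distribution, so one log gives log-probabilities, whereas
        # log(log(p)) is NaN for every p in (0, 1).
        p = torch.tensor([0.7, 0.2, 0.1])
        print(torch.log(p))             # valid log-probabilities
        print(torch.log(torch.log(p)))  # all NaN

Running it prints the decoded id matrix from the toy greedy loop, then valid log-probabilities for the single log and NaNs for the double log.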