bug fixes
commit 6b4e29ae04
parent 06131f12dc
@@ -182,7 +182,7 @@ class MQANDecoder(nn.Module):
         decoder_wrapper = MQANDecoderWrapper(self_attended_context, context, context_padding, question, question_padding, context_indices, question_indices,
                                              decoder_vocab, rnn_state, batch_size, max_decoder_time, self, num_beams=self.args.num_beams)

         if self.args.num_beams > 1:
             outputs = self._decode_beam_search(
                 input_ids=input_ids,
@@ -204,25 +204,27 @@ class MQANDecoder(nn.Module):
             outputs = self._decode_greedy(
                 input_ids=input_ids,
                 max_length=self.args.max_output_length,
                 pad_token_id=decoder_vocab.pad_idx,
                 eos_token_id=decoder_vocab.eos_idx,
                 batch_size=batch_size,
                 decoder_wrapper=decoder_wrapper,
             )

         # print('outputs = ', outputs)
         # print('outputs = ', outputs.shape)
         return outputs

     def _decode_greedy(
         self,
         input_ids,
         max_length,
         pad_token_id,
         eos_token_id,
         batch_size,
         decoder_wrapper
     ):
-        outs = input_ids.new_full((batch_size, max_length), self.pad_idx, dtype=torch.long)
+        outs = input_ids.new_full((batch_size, max_length), pad_token_id, dtype=torch.long)
         eos_yet = input_ids.new_zeros((batch_size,)).byte()

         for t in range(max_length):
             probs = decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))
             pred_probs, preds = probs.max(-1)
@@ -230,8 +232,7 @@ class MQANDecoder(nn.Module):
             outs[:, t] = preds.cpu().apply_(self.map_to_full)
             if eos_yet.all():
                 break
-            preds = preds.unsqueeze(1)
-            input_ids = torch.cat((input_ids, preds), dim=1)
+            input_ids = torch.cat((input_ids, outs[:, t].unsqueeze(1)), dim=1)
         return outs

     def _decode_beam_search(
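Note on the greedy-decoding fixes above: a minimal, self-contained sketch of the loop after this change, using the caller-supplied pad_token_id (instead of self.pad_idx) and feeding the token already written to outs back into input_ids. The fake_probs model and the token ids are made up for illustration; the real code gets its distribution from MQANDecoderWrapper.next_token_probs and also remaps predictions with self.map_to_full, which is omitted here.

import torch

def greedy_decode(input_ids, next_token_probs, max_length, pad_token_id, eos_token_id):
    # input_ids: (batch_size, 1) start tokens
    batch_size = input_ids.size(0)
    # pre-fill the output buffer with the pad id passed in by the caller (the fix),
    # not an attribute looked up on the decoder object
    outs = input_ids.new_full((batch_size, max_length), pad_token_id, dtype=torch.long)
    eos_yet = input_ids.new_zeros((batch_size,), dtype=torch.bool)

    for t in range(max_length):
        probs = next_token_probs(input_ids[:, -1].unsqueeze(-1))  # (batch_size, vocab_size)
        preds = probs.argmax(-1)                                  # (batch_size,)
        # keep emitting pad tokens for sequences that already produced EOS
        preds = torch.where(eos_yet, torch.full_like(preds, pad_token_id), preds)
        eos_yet = eos_yet | (preds == eos_token_id)
        outs[:, t] = preds
        if eos_yet.all():
            break
        # feed the token written to outs back in, mirroring the fixed torch.cat line
        input_ids = torch.cat((input_ids, outs[:, t].unsqueeze(1)), dim=1)
    return outs

# toy usage with a random "model"; shapes only, not a real decoder
vocab_size = 10
fake_probs = lambda last: torch.softmax(torch.randn(last.size(0), vocab_size), dim=-1)
start = torch.zeros(2, 1, dtype=torch.long)
print(greedy_decode(start, fake_probs, max_length=5, pad_token_id=0, eos_token_id=3))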
@@ -272,10 +273,8 @@ class MQANDecoder(nn.Module):
         done = [False for _ in range(batch_size)]

         while cur_len < max_length:
             # print('input_ids = ', input_ids[:, -1].unsqueeze(-1))

             # next_token_probs outputs a normalized probability distribution instead of logits
-            scores = torch.log(torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))))  # (batch_size * num_beams, vocab_size)
+            scores = torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1)))  # (batch_size * num_beams, vocab_size)

             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
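On the beam-search score change above: next_token_probs already returns a normalized probability distribution (per the inline comment), so the scores only need a single torch.log; the old double log took the log of negative numbers and produced NaNs. A quick stand-alone illustration, with a random softmax standing in for the model output:

import torch

probs = torch.softmax(torch.randn(4, 10), dim=-1)  # stand-in for next_token_probs(...)

scores = torch.log(probs)             # log-probabilities, all <= 0
broken = torch.log(torch.log(probs))  # old behaviour: log of negatives -> NaN

print(torch.isnan(broken).any())      # tensor(True)
print(scores.logsumexp(dim=-1))       # ~0 for every row, i.e. a proper distribution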