bug fixes

This commit is contained in:
Sina 2020-03-03 00:41:18 -08:00
parent 06131f12dc
commit 6b4e29ae04
1 changed file with 8 additions and 9 deletions

View File

@@ -204,25 +204,27 @@ class MQANDecoder(nn.Module):
outputs = self._decode_greedy(
input_ids=input_ids,
max_length=self.args.max_output_length,
pad_token_id=decoder_vocab.pad_idx,
eos_token_id=decoder_vocab.eos_idx,
batch_size=batch_size,
decoder_wrapper=decoder_wrapper,
)
# print('outputs = ', outputs)
# print('outputs = ', outputs.shape)
return outputs
def _decode_greedy(
self,
input_ids,
max_length,
pad_token_id,
eos_token_id,
batch_size,
decoder_wrapper
):
outs = input_ids.new_full((batch_size, max_length), self.pad_idx, dtype=torch.long)
eos_yet = input_ids.new_zeros((batch_size,)).byte()
outs = input_ids.new_full((batch_size, max_length), pad_token_id, dtype=torch.long)
eos_yet = input_ids.new_zeros((batch_size,)).byte()
for t in range(max_length):
probs = decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))
pred_probs, preds = probs.max(-1)
@@ -230,8 +232,7 @@ class MQANDecoder(nn.Module):
outs[:, t] = preds.cpu().apply_(self.map_to_full)
if eos_yet.all():
break
preds = preds.unsqueeze(1)
input_ids = torch.cat((input_ids, preds), dim=1)
input_ids = torch.cat((input_ids, outs[:, t].unsqueeze(1)), dim=1)
return outs
def _decode_beam_search(
@@ -272,10 +273,8 @@ class MQANDecoder(nn.Module):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
# print('input_ids = ', input_ids[:, -1].unsqueeze(-1))
# next_token_probs outputs a normalized probability distribution instead of logits
scores = torch.log(torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1)))) # (batch_size * num_beams, vocab_size)
scores = torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))) # (batch_size * num_beams, vocab_size)
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
if repetition_penalty != 1.0: