bug fixes
commit 6b4e29ae04
parent 06131f12dc
@@ -204,25 +204,27 @@ class MQANDecoder(nn.Module):
             outputs = self._decode_greedy(
                 input_ids=input_ids,
                 max_length=self.args.max_output_length,
+                pad_token_id=decoder_vocab.pad_idx,
                 eos_token_id=decoder_vocab.eos_idx,
                 batch_size=batch_size,
                 decoder_wrapper=decoder_wrapper,
             )
 
-            # print('outputs = ', outputs)
+            # print('outputs = ', outputs.shape)
             return outputs
 
     def _decode_greedy(
         self,
         input_ids,
         max_length,
+        pad_token_id,
         eos_token_id,
         batch_size,
         decoder_wrapper
     ):
-        outs = input_ids.new_full((batch_size, max_length), self.pad_idx, dtype=torch.long)
-        eos_yet = input_ids.new_zeros((batch_size,)).byte()
 
+        outs = input_ids.new_full((batch_size, max_length), pad_token_id, dtype=torch.long)
+        eos_yet = input_ids.new_zeros((batch_size,)).byte()
         for t in range(max_length):
             probs = decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))
             pred_probs, preds = probs.max(-1)
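The first hunk threads the decoder vocabulary's pad index into `_decode_greedy` explicitly instead of relying on `self.pad_idx`, so the pre-filled output buffer uses the same id the rest of the pipeline treats as padding. A minimal sketch of the pre-fill pattern with toy shapes and a hypothetical pad id (none of this is the repo's code):

import torch

# Toy values for illustration; the real call passes decoder_vocab.pad_idx.
batch_size, max_length, pad_token_id = 2, 5, 0
input_ids = torch.tensor([[7], [7]])  # stand-in prompt ids

# new_full inherits device and layout from input_ids, so the output buffer
# lands on the same device as the inputs and starts out all-pad; positions
# never written (early break on EOS) keep the pad id.
outs = input_ids.new_full((batch_size, max_length), pad_token_id, dtype=torch.long)
print(outs)
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0]])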
@@ -230,8 +232,7 @@ class MQANDecoder(nn.Module):
             outs[:, t] = preds.cpu().apply_(self.map_to_full)
             if eos_yet.all():
                 break
-            preds = preds.unsqueeze(1)
-            input_ids = torch.cat((input_ids, preds), dim=1)
+            input_ids = torch.cat((input_ids, outs[:, t].unsqueeze(1)), dim=1)
         return outs
 
     def _decode_beam_search(
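The second hunk fixes what gets fed back into the decoder: the old loop appended the raw decoder-vocab `preds`, while `outs[:, t]` holds the same predictions after `map_to_full`, so the next step now conditions on ids from the same vocabulary as `input_ids`. A sketch of the fixed loop under stated assumptions: `decoder_wrapper.next_token_probs` and `map_to_full` stand in for the class's own helpers, and the `eos_yet` update is assumed since the hunk elides it:

import torch

def greedy_loop(input_ids, outs, max_length, eos_token_id, decoder_wrapper, map_to_full):
    # Sketch only; mirrors the hunk, not the full method.
    eos_yet = input_ids.new_zeros((input_ids.size(0),), dtype=torch.bool)
    for t in range(max_length):
        probs = decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))
        _, preds = probs.max(-1)                      # decoder-vocab ids, shape (batch,)
        eos_yet = eos_yet | (preds == eos_token_id)   # assumed update; elided in the hunk
        outs[:, t] = preds.cpu().apply_(map_to_full)  # mapped to full-vocab ids
        if eos_yet.all():
            break
        # The fix: append the mapped full-vocab id, not the raw decoder-vocab pred.
        input_ids = torch.cat((input_ids, outs[:, t].unsqueeze(1)), dim=1)
    return outs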
@@ -272,10 +273,8 @@ class MQANDecoder(nn.Module):
         done = [False for _ in range(batch_size)]
 
         while cur_len < max_length:
-            # print('input_ids = ', input_ids[:, -1].unsqueeze(-1))
-
             # next_token_probs outputs a normalized probability distribution instead of logits
-            scores = torch.log(torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1))))  # (batch_size * num_beams, vocab_size)
+            scores = torch.log(decoder_wrapper.next_token_probs(input_ids[:, -1].unsqueeze(-1)))  # (batch_size * num_beams, vocab_size)
 
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
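The third hunk is the headline fix: `next_token_probs` already returns a normalized probability distribution, so a single `torch.log` yields the log-probabilities beam search needs. The old double `torch.log` took the log of negative numbers and turned every score into NaN. A quick demonstration with a toy distribution:

import torch

probs = torch.tensor([[0.7, 0.2, 0.1]])  # toy normalized distribution
scores = torch.log(probs)                # valid log-probs: [[-0.3567, -1.6094, -2.3026]]
broken = torch.log(torch.log(probs))     # log of negative values -> tensor([[nan, nan, nan]])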
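The diff cuts off at the repetition-penalty branch it leaves unchanged. For reference, a sketch of the CTRL-style penalty (https://arxiv.org/abs/1909.05858) as it is commonly applied to log-prob scores; `prev_ids` and the loop structure here are illustrative, not the repo's implementation:

import torch

def apply_repetition_penalty(scores, prev_ids, repetition_penalty):
    # scores: (batch, vocab) log-probs; prev_ids: (batch, seq) tokens generated so far.
    for b in range(scores.size(0)):
        for tok in set(prev_ids[b].tolist()):
            # Log-probs are <= 0, so multiplying by a penalty > 1 makes
            # already-generated tokens less likely to repeat.
            scores[b, tok] *= repetition_penalty
    return scores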