add assertion to avoid nan in std
This commit is contained in:
parent
74401aa03f
commit
70fc430a42
|
@ -111,7 +111,7 @@ class Seq2Seq(torch.nn.Module):
|
|||
context_rnn_state = context_rnn_state.view(batch_size, -1)
|
||||
|
||||
if self.args.encoder_loss_type == 'mean':
|
||||
# element-wise mean of encoder loss #https://www.aclweb.org/anthology/W18-3023.pdf
|
||||
# element-wise mean of encoder loss https://www.aclweb.org/anthology/W18-3023.pdf
|
||||
context_value = torch.mean(context_rnn_state, dim=-1)
|
||||
elif self.args.encoder_loss_type == 'sum':
|
||||
context_value = torch.sum(context_rnn_state, dim=-1)
|
||||
|
@ -120,6 +120,7 @@ class Seq2Seq(torch.nn.Module):
|
|||
for i in range(0, batch_size, groups):
|
||||
indices = [j for j in range(i, i+groups)]
|
||||
groups_vals = context_value[indices]
|
||||
assert len(groups_vals) > 1
|
||||
encoder_loss += torch.std(groups_vals).item()
|
||||
|
||||
return encoder_loss
|
||||
|
|
Loading…
Reference in New Issue