add assertion to avoid nan in std

This commit is contained in:
mehrad 2020-04-01 16:47:32 -07:00
parent 74401aa03f
commit 70fc430a42
1 changed files with 2 additions and 1 deletions

View File

@ -111,7 +111,7 @@ class Seq2Seq(torch.nn.Module):
context_rnn_state = context_rnn_state.view(batch_size, -1)
if self.args.encoder_loss_type == 'mean':
# element-wise mean of encoder loss #https://www.aclweb.org/anthology/W18-3023.pdf
# element-wise mean of encoder loss https://www.aclweb.org/anthology/W18-3023.pdf
context_value = torch.mean(context_rnn_state, dim=-1)
elif self.args.encoder_loss_type == 'sum':
context_value = torch.sum(context_rnn_state, dim=-1)
@ -120,6 +120,7 @@ class Seq2Seq(torch.nn.Module):
for i in range(0, batch_size, groups):
indices = [j for j in range(i, i+groups)]
groups_vals = context_value[indices]
assert len(groups_vals) > 1
encoder_loss += torch.std(groups_vals).item()
return encoder_loss