add assertion to avoid nan in std
This commit is contained in:
parent
74401aa03f
commit
70fc430a42
|
@ -111,7 +111,7 @@ class Seq2Seq(torch.nn.Module):
|
||||||
context_rnn_state = context_rnn_state.view(batch_size, -1)
|
context_rnn_state = context_rnn_state.view(batch_size, -1)
|
||||||
|
|
||||||
if self.args.encoder_loss_type == 'mean':
|
if self.args.encoder_loss_type == 'mean':
|
||||||
# element-wise mean of encoder loss #https://www.aclweb.org/anthology/W18-3023.pdf
|
# element-wise mean of encoder loss https://www.aclweb.org/anthology/W18-3023.pdf
|
||||||
context_value = torch.mean(context_rnn_state, dim=-1)
|
context_value = torch.mean(context_rnn_state, dim=-1)
|
||||||
elif self.args.encoder_loss_type == 'sum':
|
elif self.args.encoder_loss_type == 'sum':
|
||||||
context_value = torch.sum(context_rnn_state, dim=-1)
|
context_value = torch.sum(context_rnn_state, dim=-1)
|
||||||
|
@ -120,6 +120,7 @@ class Seq2Seq(torch.nn.Module):
|
||||||
for i in range(0, batch_size, groups):
|
for i in range(0, batch_size, groups):
|
||||||
indices = [j for j in range(i, i+groups)]
|
indices = [j for j in range(i, i+groups)]
|
||||||
groups_vals = context_value[indices]
|
groups_vals = context_value[indices]
|
||||||
|
assert len(groups_vals) > 1
|
||||||
encoder_loss += torch.std(groups_vals).item()
|
encoder_loss += torch.std(groups_vals).item()
|
||||||
|
|
||||||
return encoder_loss
|
return encoder_loss
|
||||||
|
|
Loading…
Reference in New Issue