diff --git a/genienlp/models/general_seq2seq.py b/genienlp/models/general_seq2seq.py index dc815557..b49e4494 100644 --- a/genienlp/models/general_seq2seq.py +++ b/genienlp/models/general_seq2seq.py @@ -111,7 +111,7 @@ class Seq2Seq(torch.nn.Module): context_rnn_state = context_rnn_state.view(batch_size, -1) if self.args.encoder_loss_type == 'mean': - # element-wise mean of encoder loss #https://www.aclweb.org/anthology/W18-3023.pdf + # element-wise mean of encoder loss https://www.aclweb.org/anthology/W18-3023.pdf context_value = torch.mean(context_rnn_state, dim=-1) elif self.args.encoder_loss_type == 'sum': context_value = torch.sum(context_rnn_state, dim=-1) @@ -120,6 +120,7 @@ class Seq2Seq(torch.nn.Module): for i in range(0, batch_size, groups): indices = [j for j in range(i, i+groups)] groups_vals = context_value[indices] + assert len(groups_vals) > 1 encoder_loss += torch.std(groups_vals).item() return encoder_loss