diff --git a/genienlp/models/identity_encoder.py b/genienlp/models/identity_encoder.py index 76f9f0cb..dc17c3ab 100644 --- a/genienlp/models/identity_encoder.py +++ b/genienlp/models/identity_encoder.py @@ -54,7 +54,7 @@ class IdentityEncoder(nn.Module): if self.args.rnn_layers > 0 and self.args.rnn_zero_state == 'average': self.pool = LinearFeedforward(args.dimension, args.dimension, 2 * args.rnn_dimension * args.rnn_layers, dropout=args.dropout_ratio) - self.norm = LayerNorm(2 * args.rnn_dimension) + self.norm = LayerNorm(2 * args.rnn_dimension * args.rnn_layers) else: self.pool = None self.norm = None