From 7038b8d41a13bda49b4dd635ad8093032da8518d Mon Sep 17 00:00:00 2001
From: Galaxy-Husky <598756381@qq.com>
Date: Thu, 15 Aug 2024 16:55:13 +0800
Subject: [PATCH] fix: correct the positional encoding of Transformer in pytorch examples

---
 src/lightning/pytorch/demos/transformer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index ac83b5539f..58cf30cbca 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -88,7 +88,7 @@ class PositionalEncoding(nn.Module):
             # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
             self.pe = self._init_pos_encoding(device=x.device)
 
-        x = x + self.pe[: x.size(0), :]
+        x = x + self.pe[:, : x.size(1)]
         return self.dropout(x)
 
     def _init_pos_encoding(self, device: torch.device) -> Tensor:
@@ -97,7 +97,7 @@ class PositionalEncoding(nn.Module):
         div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
         pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
+        pe = pe.unsqueeze(0)
         return pe
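
Note (not part of the patch): the change treats the input to PositionalEncoding as batch-first, i.e. shaped (batch, seq_len, dim), so the encoding table is built as (1, max_len, dim) and sliced along the sequence axis instead of the batch axis. The standalone sketch below illustrates the resulting shapes; max_len, dim, and the example input sizes are arbitrary values chosen for illustration.

import math

import torch

# Sketch of the corrected encoding (not part of the patch);
# max_len/dim and the input shape below are arbitrary example values.
max_len, dim = 5000, 64
pe = torch.zeros(max_len, dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)            # (1, max_len, dim), broadcasts over the batch

x = torch.randn(8, 35, dim)     # batch-first input: (batch, seq_len, dim)
out = x + pe[:, : x.size(1)]    # slice the sequence axis, not the batch axis
assert out.shape == x.shape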