fix: correct the positional encoding of Transformer in pytorch examples

Galaxy-Husky 2024-08-15 16:55:13 +08:00 committed by Luca Antiga
parent b0aa504f80
commit 7038b8d41a
1 changed file with 2 additions and 2 deletions


@@ -88,7 +88,7 @@ class PositionalEncoding(nn.Module):
         # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
         self.pe = self._init_pos_encoding(device=x.device)
-        x = x + self.pe[: x.size(0), :]
+        x = x + self.pe[:, : x.size(1)]
         return self.dropout(x)

     def _init_pos_encoding(self, device: torch.device) -> Tensor:
@@ -97,7 +97,7 @@ class PositionalEncoding(nn.Module):
         div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
+        pe = pe.unsqueeze(0)
         return pe
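
The diff only touches the two changed lines above, so for context, here is a minimal self-contained sketch of the corrected module, assuming a batch-first input of shape (batch, seq_len, dim) with an even dim. The constructor arguments (dim, dropout, max_len) and the lazy per-device initialization of self.pe are assumptions for illustration, not necessarily how the surrounding example file defines them.

    import math

    import torch
    from torch import Tensor, nn


    class PositionalEncoding(nn.Module):
        """Sinusoidal positional encoding for batch-first inputs of shape (batch, seq_len, dim)."""

        def __init__(self, dim: int, dropout: float = 0.1, max_len: int = 5000) -> None:
            super().__init__()
            self.dim = dim
            self.max_len = max_len
            self.dropout = nn.Dropout(p=dropout)
            self.pe = None  # built lazily on the input's device (assumption)

        def forward(self, x: Tensor) -> Tensor:
            if self.pe is None or self.pe.device != x.device:
                self.pe = self._init_pos_encoding(device=x.device)
            # Slice along the sequence dimension (dim 1), matching the fixed line above.
            x = x + self.pe[:, : x.size(1)]
            return self.dropout(x)

        def _init_pos_encoding(self, device: torch.device) -> Tensor:
            position = torch.arange(0, self.max_len, device=device).unsqueeze(1).float()
            div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
            pe = torch.zeros(self.max_len, self.dim, device=device)
            pe[:, 0::2] = torch.sin(position * div_term)  # even columns
            pe[:, 1::2] = torch.cos(position * div_term)  # odd columns
            # (max_len, dim) -> (1, max_len, dim): the leading dim broadcasts over the batch.
            pe = pe.unsqueeze(0)
            return pe

Quick check of the shapes with hypothetical sizes: PositionalEncoding(dim=64)(torch.zeros(2, 10, 64)) adds a (1, 10, 64) slice of pe to the (2, 10, 64) input, whereas the old transpose(0, 1) produced a (max_len, 1, dim) buffer that only lined up with sequence-first inputs.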