fix: correct the positional encoding of Transformer in pytorch examples
This commit is contained in:
parent
b0aa504f80
commit
7038b8d41a
|
@ -88,7 +88,7 @@ class PositionalEncoding(nn.Module):
|
|||
# TODO: Could make this a `nn.Parameter` with `requires_grad=False`
|
||||
self.pe = self._init_pos_encoding(device=x.device)
|
||||
|
||||
x = x + self.pe[: x.size(0), :]
|
||||
x = x + self.pe[:, x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
def _init_pos_encoding(self, device: torch.device) -> Tensor:
|
||||
|
@ -97,7 +97,7 @@ class PositionalEncoding(nn.Module):
|
|||
div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose(0, 1)
|
||||
pe = pe.unsqueeze(0)
|
||||
return pe
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue