diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index ac83b5539f..58cf30cbca 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -88,7 +88,7 @@ class PositionalEncoding(nn.Module):
             # TODO: Could make this a `nn.Parameter` with `requires_grad=False`
             self.pe = self._init_pos_encoding(device=x.device)
 
-        x = x + self.pe[: x.size(0), :]
+        x = x + self.pe[:, : x.size(1)]
         return self.dropout(x)
 
     def _init_pos_encoding(self, device: torch.device) -> Tensor:
@@ -97,7 +97,7 @@ class PositionalEncoding(nn.Module):
         pe = torch.zeros(self.max_len, self.dim, device=device)
         position = torch.arange(0, self.max_len, dtype=torch.float, device=device).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, self.dim, 2, device=device).float() * (-math.log(10000.0) / self.dim))
         pe[:, 0::2] = torch.sin(position * div_term)
         pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(0).transpose(0, 1)
+        pe = pe.unsqueeze(0)
         return pe
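
For context, a quick standalone sanity check of the batch-first broadcasting this patch enables. This is a minimal sketch for reviewers, not part of the patch; `init_pos_encoding` is an illustrative free-function version of `PositionalEncoding._init_pos_encoding`, and the shapes are assumptions based on the diff:

    import math

    import torch


    def init_pos_encoding(max_len: int, dim: int) -> torch.Tensor:
        # Sinusoidal table; with the patch it keeps shape (1, max_len, dim)
        # instead of being transposed to (max_len, 1, dim).
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0)  # (1, max_len, dim)


    batch, seq_len, dim = 4, 10, 16
    x = torch.randn(batch, seq_len, dim)  # batch-first input, as the demo now expects
    pe = init_pos_encoding(max_len=5000, dim=dim)

    # pe[:, : x.size(1)] has shape (1, seq_len, dim) and broadcasts over the batch
    # dimension. The old pe[: x.size(0), :] sliced the table by batch size rather
    # than sequence length once inputs became batch-first.
    out = x + pe[:, : x.size(1)]
    assert out.shape == (batch, seq_len, dim)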