diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index 6c2ad6defc..6a389fe008 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -26,7 +26,13 @@ if hasattr(MultiheadAttention, "_reset_parameters") and not hasattr(MultiheadAtt
 
 class Transformer(nn.Module):
     def __init__(
-        self, vocab_size: int, ninp: int = 200, nhead: int = 2, nhid: int = 200, nlayers: int = 2, dropout: float = 0.2
+        self,
+        vocab_size: int = 33278,  # default for WikiText2
+        ninp: int = 200,
+        nhead: int = 2,
+        nhid: int = 200,
+        nlayers: int = 2,
+        dropout: float = 0.2,
     ) -> None:
         super().__init__()
         self.pos_encoder = PositionalEncoding(ninp, dropout)
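
For context, a minimal usage sketch of what this change enables. It assumes the WikiText2 demo dataset defined alongside Transformer in this same module, and that its vocabulary size matches the new 33278 default; neither is part of the diff itself.

# Minimal sketch, not part of this patch. Assumes the WikiText2 dataset
# from lightning.pytorch.demos.transformer, whose vocabulary is assumed
# to contain 33278 tokens (the new default for vocab_size).
from lightning.pytorch.demos.transformer import Transformer, WikiText2

dataset = WikiText2()  # downloads and tokenizes WikiText-2 on first use

# After this change, the demo model can be built with no arguments:
model = Transformer()  # equivalent to Transformer(vocab_size=dataset.vocab_size)

For any other corpus, vocab_size should still be passed explicitly, since the default only matches the WikiText2 demo data.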