From 618e1c8061753e767e7ae628cf55098b8fa6ad55 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 9 Nov 2023 00:41:31 +0100
Subject: [PATCH] Provide the default vocab size for the transformer demo
 model (#18963)

---
 src/lightning/pytorch/demos/transformer.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py
index 6c2ad6defc..6a389fe008 100644
--- a/src/lightning/pytorch/demos/transformer.py
+++ b/src/lightning/pytorch/demos/transformer.py
@@ -26,7 +26,13 @@ if hasattr(MultiheadAttention, "_reset_parameters") and not hasattr(MultiheadAtt
 class Transformer(nn.Module):
     def __init__(
-        self, vocab_size: int, ninp: int = 200, nhead: int = 2, nhid: int = 200, nlayers: int = 2, dropout: float = 0.2
+        self,
+        vocab_size: int = 33278,  # default for WikiText2
+        ninp: int = 200,
+        nhead: int = 2,
+        nhid: int = 200,
+        nlayers: int = 2,
+        dropout: float = 0.2,
     ) -> None:
         super().__init__()
         self.pos_encoder = PositionalEncoding(ninp, dropout)
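
A minimal usage sketch of the effect of this change (assuming the `lightning` package is installed and exposes the demo module at the path shown in the diff): the demo Transformer can now be constructed without looking up the WikiText2 vocabulary size first.

    from lightning.pytorch.demos.transformer import Transformer

    # With the new default, vocab_size falls back to 33278,
    # the vocabulary size of the WikiText2 demo dataset.
    model = Transformer()

    # Passing an explicit vocab_size still works for other corpora.
    custom_model = Transformer(vocab_size=50000)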