Provide the default vocab size for the transformer demo model (#18963)

Adrian Wälchli 2023-11-09 00:41:31 +01:00 committed by GitHub
parent 57cc01bc0a
commit 618e1c8061
1 changed file with 7 additions and 1 deletion


@@ -26,7 +26,13 @@ if hasattr(MultiheadAttention, "_reset_parameters") and not hasattr(MultiheadAttention, "reset_parameters"):
 class Transformer(nn.Module):
     def __init__(
-        self, vocab_size: int, ninp: int = 200, nhead: int = 2, nhid: int = 200, nlayers: int = 2, dropout: float = 0.2
+        self,
+        vocab_size: int = 33278,  # default for WikiText2
+        ninp: int = 200,
+        nhead: int = 2,
+        nhid: int = 200,
+        nlayers: int = 2,
+        dropout: float = 0.2,
     ) -> None:
         super().__init__()
         self.pos_encoder = PositionalEncoding(ninp, dropout)
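
With this change the demo model can be constructed without arguments, since the default matches the vocabulary of the bundled WikiText2 demo dataset. A minimal usage sketch follows; it assumes Transformer and WikiText2 are importable from lightning.pytorch.demos (as in PyTorch Lightning releases of this era) and that WikiText2 exposes a vocab_size attribute:

    # Sketch, not from the diff: assumes lightning.pytorch.demos exports
    # Transformer and WikiText2, and that WikiText2 has a vocab_size attribute.
    from lightning.pytorch.demos import Transformer, WikiText2

    # Before this commit, vocab_size was a required argument:
    #   model = Transformer(vocab_size=33278)
    # After it, the WikiText2 default is baked in:
    model = Transformer()  # vocab_size defaults to 33278

    # For any other corpus, pass the actual vocabulary size explicitly:
    dataset = WikiText2()
    model = Transformer(vocab_size=dataset.vocab_size)

The default only removes boilerplate for the WikiText2 demo; models trained on a different corpus still need the matching vocabulary size, or the embedding and output layers will be sized wrong.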