Provide the default vocab size for the transformer demo model (#18963)
parent 57cc01bc0a
commit 618e1c8061
@@ -26,7 +26,13 @@ if hasattr(MultiheadAttention, "_reset_parameters") and not hasattr(MultiheadAtt

 class Transformer(nn.Module):
     def __init__(
-        self, vocab_size: int, ninp: int = 200, nhead: int = 2, nhid: int = 200, nlayers: int = 2, dropout: float = 0.2
+        self,
+        vocab_size: int = 33278,  # default for WikiText2
+        ninp: int = 200,
+        nhead: int = 2,
+        nhid: int = 200,
+        nlayers: int = 2,
+        dropout: float = 0.2,
     ) -> None:
         super().__init__()
         self.pos_encoder = PositionalEncoding(ninp, dropout)
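For context, here is a minimal self-contained sketch of how the demo model reads after this change, and how the new default lets callers omit vocab_size for WikiText2. The diff only shows the constructor signature, so the embedding, encoder, decoder, and forward pass below are illustrative assumptions, and PositionalEncoding is stubbed with a placeholder rather than the demo's real helper:

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    # Placeholder for the demo's PositionalEncoding; only dropout is applied here.
    def __init__(self, ninp: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(x)


class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size: int = 33278,  # default for WikiText2
        ninp: int = 200,
        nhead: int = 2,
        nhid: int = 200,
        nlayers: int = 2,
        dropout: float = 0.2,
    ) -> None:
        super().__init__()
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        # The layers below are assumed for illustration; the diff does not show them.
        self.embed = nn.Embedding(vocab_size, ninp)
        layer = nn.TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.encoder = nn.TransformerEncoder(layer, nlayers)
        self.decoder = nn.Linear(ninp, vocab_size)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        x = self.pos_encoder(self.embed(src))
        return self.decoder(self.encoder(x))


# With the new default, no vocab_size argument is needed for WikiText2:
model = Transformer()
tokens = torch.randint(0, 33278, (35, 20))  # (seq_len, batch)
logits = model(tokens)  # shape: (35, 20, 33278)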