Update embeddings.py
This commit is contained in:
parent
3e9f37a382
commit
a01e2ade25
|
@ -97,11 +97,11 @@ class PretrainedEmbedding(torch.nn.Embedding):
|
|||
|
||||
if __name__ == '__main__':
|
||||
emb = PretrainedEmbedding(
|
||||
embedding_path='/Users/waf/Developer/NGV/research-fermat/fermat/.vector_cache/glove.840B.300d.txt',
|
||||
embedding_path='/Users/waf/Developer',
|
||||
embedding_dim=300,
|
||||
task_vocab={'hello': 1, 'world': 2}
|
||||
)
|
||||
|
||||
data = torch.Tensor([[0, 1], [0, 2]]).long()
|
||||
embedded = emb(data)
|
||||
print(embedded)
|
||||
print(embedded)
|
||||
|
|
Loading…
Reference in New Issue