Update embeddings.py

This commit is contained in:
William Falcon 2019-04-03 11:21:16 -04:00 committed by GitHub
parent a01e2ade25
commit bca1c4b594
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 6 deletions

View File

@ -13,11 +13,7 @@ class PretrainedEmbedding(torch.nn.Embedding):
>>> emb = PretrainedEmbedding(embedding_path='glove.840B.300d.txt',embedding_dim=300, task_vocab={'hello': 1, 'world': 2}) >>> emb = PretrainedEmbedding(embedding_path='glove.840B.300d.txt',embedding_dim=300, task_vocab={'hello': 1, 'world': 2})
>>> data = torch.Tensor([[0, 1], [0, 2]]).long() >>> data = torch.Tensor([[0, 1], [0, 2]]).long()
>>> embedded = emb(data) >>> embedded = emb(data)
tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.2523, 0.1018, -0.6748, ..., 0.1787, -0.5192, 0.3359]],
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.0067, 0.2224, 0.2771, ..., 0.0594, 0.0014, 0.0987]]])
:param embedding_path: :param embedding_path:
@ -37,7 +33,8 @@ class PretrainedEmbedding(torch.nn.Embedding):
self.weight = new_emb.weight self.weight = new_emb.weight
# apply freeze # apply freeze
self.weight.requires_grad = not freeze should_freeze = not freeze
self.weight.requires_grad = should_freeze
def __load_task_specific_embeddings(self, vocab_words, embedding_path, emb_dim, freeze): def __load_task_specific_embeddings(self, vocab_words, embedding_path, emb_dim, freeze):
""" """