Update embeddings.py
This commit is contained in:
parent
a01e2ade25
commit
bca1c4b594
|
@ -13,11 +13,7 @@ class PretrainedEmbedding(torch.nn.Embedding):
|
|||
>>> emb = PretrainedEmbedding(embedding_path='glove.840B.300d.txt',embedding_dim=300, task_vocab={'hello': 1, 'world': 2})
|
||||
>>> data = torch.Tensor([[0, 1], [0, 2]]).long()
|
||||
>>> embedded = emb(data)
|
||||
tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
|
||||
[ 0.2523, 0.1018, -0.6748, ..., 0.1787, -0.5192, 0.3359]],
|
||||
|
||||
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
|
||||
[-0.0067, 0.2224, 0.2771, ..., 0.0594, 0.0014, 0.0987]]])
|
||||
|
||||
|
||||
:param embedding_path:
|
||||
|
@ -37,7 +33,8 @@ class PretrainedEmbedding(torch.nn.Embedding):
|
|||
self.weight = new_emb.weight
|
||||
|
||||
# apply freeze
|
||||
self.weight.requires_grad = not freeze
|
||||
should_freeze = not freeze
|
||||
self.weight.requires_grad = should_freeze
|
||||
|
||||
def __load_task_specific_embeddings(self, vocab_words, embedding_path, emb_dim, freeze):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue