diff --git a/pytorch_lightning/utils/embeddings.py b/pytorch_lightning/utils/embeddings.py index 507848b242..4e96c61b12 100644 --- a/pytorch_lightning/utils/embeddings.py +++ b/pytorch_lightning/utils/embeddings.py @@ -13,11 +13,7 @@ class PretrainedEmbedding(torch.nn.Embedding): >>> emb = PretrainedEmbedding(embedding_path='glove.840B.300d.txt',embedding_dim=300, task_vocab={'hello': 1, 'world': 2}) >>> data = torch.Tensor([[0, 1], [0, 2]]).long() >>> embedded = emb(data) - tensor([[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], - [ 0.2523, 0.1018, -0.6748, ..., 0.1787, -0.5192, 0.3359]], - - [[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], - [-0.0067, 0.2224, 0.2771, ..., 0.0594, 0.0014, 0.0987]]]) + :param embedding_path: @@ -37,7 +33,8 @@ class PretrainedEmbedding(torch.nn.Embedding): self.weight = new_emb.weight # apply freeze - self.weight.requires_grad = not freeze + should_freeze = not freeze + self.weight.requires_grad = should_freeze def __load_task_specific_embeddings(self, vocab_words, embedding_path, emb_dim, freeze): """