lightning/pytorch_lightning/utils/embeddings.py

105 lines
3.4 KiB
Python

import torch
import numpy as np
from copy import deepcopy
class PretrainedEmbedding(torch.nn.Embedding):
    """Embedding layer initialised from a text file of pretrained word vectors.

    Only the vectors of words present in ``task_vocab`` are read from the
    file. Vocabulary words missing from the file receive a random normal
    vector. Row 0 is the padding row and always stays all-zero.
    """

    def __init__(self, embedding_path, embedding_dim, task_vocab, freeze=True, *args, **kwargs):
        """
        Loads a prebuilt pytorch embedding from a text embedding file in which
        every line is ``word v1 v2 ... vN`` separated by single spaces.
        Padding=0 by default.

        >>> emb = PretrainedEmbedding(embedding_path='glove.840B.300d.txt', embedding_dim=300, task_vocab={'hello': 1, 'world': 2})  # doctest: +SKIP
        >>> data = torch.Tensor([[0, 1], [0, 2]]).long()  # doctest: +SKIP
        >>> embedded = emb(data)  # doctest: +SKIP

        :param embedding_path: path of the UTF-8 text file holding the vectors
        :param embedding_dim: number of values per word vector
        :param task_vocab: dict mapping word -> row index; index 0 is reserved
            for padding and is never filled. The dict is not modified.
        :param freeze: when True the weights are not trainable
            (``requires_grad=False``)
        :param args: forwarded to ``torch.nn.Embedding.__init__``
        :param kwargs: forwarded to ``torch.nn.Embedding.__init__``;
            ``padding_idx`` is fixed to 0 here and must not be supplied again
        :raises ValueError: if a vocabulary word's line does not hold exactly
            ``embedding_dim`` values, or a value is not a number
        """
        # Size the table from the largest index. ``default=0`` keeps an empty
        # vocabulary valid: the table then holds only the padding row.
        vocab_size = max(task_vocab.values(), default=0) + 1
        super(PretrainedEmbedding, self).__init__(vocab_size, embedding_dim, padding_idx=0, *args, **kwargs)

        # Assigned only after the parent constructor has set up the module.
        self.vocab_size = vocab_size

        # The loader consumes its vocab argument, so hand it a private copy.
        new_emb = self.__load_task_specific_embeddings(deepcopy(task_vocab), embedding_path, embedding_dim, freeze)

        # Replace the randomly initialised table with the pretrained one.
        self.weight = new_emb.weight

        # A frozen embedding must not receive gradient updates.
        self.weight.requires_grad = not freeze

    def __load_task_specific_embeddings(self, vocab_words, embedding_path, emb_dim, freeze):
        """
        Iterates the embedding file to only pull out task specific embeddings.

        :param vocab_words: dict mapping word -> row index. It is mutated:
            every word found in the file is removed from it.
        :param embedding_path: path of the UTF-8 text file holding the vectors
        :param emb_dim: number of values per word vector
        :param freeze: passed to ``torch.nn.Embedding.from_pretrained``
        :return: a ``torch.nn.Embedding`` of shape (self.vocab_size, emb_dim)
        :raises ValueError: if a vocabulary word's line does not hold exactly
            ``emb_dim`` values, or a value is not a number
        """
        # Final table; rows never written (including the pad row 0) stay zero.
        # float32 matches the tensor type produced at the end.
        embeddings = np.zeros(shape=(self.vocab_size, emb_dim), dtype=np.float32)

        with open(embedding_path, encoding='utf-8') as f:
            for line_no, line in enumerate(f):
                # Stop reading as soon as every vocab word has been resolved.
                if not vocab_words:
                    break

                # Peel off the word first so lines for irrelevant words are
                # never split into their (possibly hundreds of) values.
                # rstrip() drops the line ending without eating a character
                # when the final line has no trailing newline.
                word, _, rest = line.rstrip().partition(' ')
                if word not in vocab_words:
                    continue

                # Skip words that try to overwrite the pad idx.
                if vocab_words[word] == 0:
                    del vocab_words[word]
                    continue

                values = rest.split(' ') if rest else []
                if not values:
                    # Blank / vector-less line: nothing to load from it.
                    continue

                if len(values) != emb_dim:
                    is_header = (
                        line_no == 0
                        and len(values) == 1
                        and word.isdigit()
                        and values[0].isdigit()
                    )
                    if is_header:
                        # "<count> <dim>" first line of word2vec/fastText
                        # text files; it is not a word vector.
                        continue
                    raise ValueError(
                        'line {} of {}: expected {} values for word {!r}, found {}'.format(
                            line_no + 1, embedding_path, emb_dim, word, len(values)
                        )
                    )

                embeddings[vocab_words[word]] = [float(x) for x in values]

                # Remove the vocab word so the loop can terminate early.
                del vocab_words[word]

        # Add random vectors for the non-pretrained words, i.e. the vocab
        # words NOT found in the pretrained embeddings.
        for index in vocab_words.values():
            # Skip words that try to overwrite the pad idx.
            if index == 0:
                continue
            embeddings[index] = np.random.normal(size=emb_dim)

        # Turn into a pytorch embedding.
        return torch.nn.Embedding.from_pretrained(torch.from_numpy(embeddings), freeze=freeze)
if __name__ == '__main__':
    # Manual smoke test: embed a tiny two-word vocabulary and print the result.
    # The embedding file is taken from the command line instead of a
    # machine-specific hard-coded path, so the demo can run anywhere.
    import sys

    if len(sys.argv) < 2:
        sys.exit('usage: python embeddings.py EMBEDDING_FILE [EMBEDDING_DIM]')

    # The dimension defaults to 300 and may be overridden by a second argument.
    demo_dim = int(sys.argv[2]) if len(sys.argv) > 2 else 300

    emb = PretrainedEmbedding(
        embedding_path=sys.argv[1],
        embedding_dim=demo_dim,
        task_vocab={'hello': 1, 'world': 2},
    )

    # Index 0 is the padding row; 1 and 2 are the two vocabulary words.
    data = torch.tensor([[0, 1], [0, 2]], dtype=torch.long)
    embedded = emb(data)
    print(embedded)