Contextualized Word Vectors (CoVe; McCann et al. 2017)

Bryan Marcus McCann 2018-08-28 03:14:41 +00:00
parent 3c8d5e3b7e
commit 8e85a13c87
5 changed files with 21 additions and 7 deletions

@@ -62,6 +62,7 @@ def parse():
     parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
     parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
     parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
+    parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
     parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
     parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
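
A quick standalone sketch (not part of the commit) of how the two boolean flags above behave once parsed; only the option names come from the diff, everything else is illustrative:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
    parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')

    args = parser.parse_args(['--cove'])
    print(args.cove)            # True only when --cove is passed; defaults to False
    print(args.transformer_lr)  # defaults to True; passing --no_transformer_lr flips it to False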

@@ -61,6 +61,7 @@ RUN apt-get install --yes \
     python-lxml
 # WikiSQL evaluation
+RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove
 RUN pip install records
 RUN pip install babel
 RUN pip install tabulate
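
A hedged sanity check (not part of the Dockerfile) that the editable install of salesforce/cove above is usable; MTLSTM and model_cache mirror how the model code in this commit uses the package, while the cache path is only illustrative:

    # Run inside the built image; loads the pretrained MT-LSTM weights on first use.
    from cove import MTLSTM

    cove = MTLSTM(model_cache='.embeddings')  # '.embeddings' stands in for args.embeddings
    print(cove)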

@@ -320,23 +320,25 @@ class Feedforward(nn.Module):
 class Embedding(nn.Module):
-    def __init__(self, field, trained_dimension, dropout=0.0):
+    def __init__(self, field, trained_dimension, dropout=0.0, project=True):
         super().__init__()
         self.field = field
+        self.project = project
         dimension = 0
         pretrained_dimension = field.vocab.vectors.size(-1)
         self.pretrained_embeddings = [nn.Embedding(len(field.vocab), pretrained_dimension)]
         self.pretrained_embeddings[0].weight.data = field.vocab.vectors
         self.pretrained_embeddings[0].weight.requires_grad = False
         dimension += pretrained_dimension
-        self.projection = Feedforward(dimension, trained_dimension)
+        if self.project:
+            self.projection = Feedforward(dimension, trained_dimension)
         dimension = trained_dimension
         self.dropout = nn.Dropout(0.2)
         self.dimension = dimension
     def forward(self, x, lengths=None):
         pretrained_embeddings = self.pretrained_embeddings[0](x.cpu()).cuda().detach()
-        return self.projection(pretrained_embeddings)
+        return self.projection(pretrained_embeddings) if self.project else pretrained_embeddings
     def set_embeddings(self, w):
         self.pretrained_embeddings[0].weight.data = w
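
To make the effect of the new project switch concrete, a rough standalone sketch with assumed sizes (400 = CharNGram 100 + GloVe 300, and 200 standing in for trained_dimension): project=False now passes the frozen pretrained embeddings through unchanged, while project=True keeps the old behaviour of projecting them down.

    import torch
    from torch import nn

    vocab_size, pretrained_dimension, trained_dimension = 1000, 400, 200
    vectors = torch.randn(vocab_size, pretrained_dimension)          # stand-in for field.vocab.vectors

    embed = nn.Embedding(vocab_size, pretrained_dimension)
    embed.weight.data = vectors
    embed.weight.requires_grad = False
    projection = nn.Linear(pretrained_dimension, trained_dimension)  # stand-in for Feedforward

    tokens = torch.randint(0, vocab_size, (2, 5))                    # (batch, time) vocabulary indices
    unprojected = embed(tokens)                                      # (2, 5, 400) -> what project=False returns
    projected = projection(unprojected)                              # (2, 5, 200) -> the project=True path
    print(unprojected.size(), projected.size())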

@@ -7,6 +7,8 @@ from torch import nn
 from torch.nn import functional as F
 from torch.autograd import Variable
+from cove import MTLSTM
 from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask
@@ -19,11 +21,15 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
         self.encoder_embeddings = Embedding(field, args.dimension,
-            dropout=args.dropout_ratio)
+            dropout=args.dropout_ratio, project=not args.cove)
         self.decoder_embeddings = Embedding(field, args.dimension,
-            dropout=args.dropout_ratio)
+            dropout=args.dropout_ratio, project=True)
+        if self.args.cove:
+            self.cove = MTLSTM(model_cache=args.embeddings)
+            self.project_cove = Feedforward(1000, args.dimension)
-        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
+        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
             batch_first=True, dropout=args.dropout_ratio, bidirectional=True, num_layers=1)
         self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
         dim = 2*args.dimension + args.dimension + args.dimension
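
For context on the two new modules, a hedged note on their sizes (the 300/600/1000 figures follow from the GloVe setup in this repo and the CoVe paper, not from anything printed in this diff): MTLSTM is the pretrained translation BiLSTM of McCann et al. 2017, it consumes 300-dim GloVe embeddings and returns 600-dim CoVe vectors, and concatenating those with the 400-dim unprojected encoder embeddings yields the 1000 inputs expected by Feedforward(1000, args.dimension). A minimal sketch of the call, with an illustrative cache path:

    import torch
    from cove import MTLSTM

    cove = MTLSTM(model_cache='.embeddings')       # '.embeddings' stands in for args.embeddings
    glove = torch.randn(2, 5, 300)                 # pretend GloVe embeddings, shape (batch, time, 300)
    lengths = torch.LongTensor([5, 3])             # true sequence lengths within the padded batch
    print(cove(glove, lengths).size())             # expected: torch.Size([2, 5, 600])
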
@@ -69,6 +75,9 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         context_embedded = self.encoder_embeddings(context)
         question_embedded = self.encoder_embeddings(question)
+        if self.args.cove:
+            context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, -300:], context_lengths), context_embedded], -1).detach())
+            question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, -300:], question_lengths), question_embedded], -1).detach())
         context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
         question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
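
And a rough sketch (random tensors, assumed sizes, nn.Linear standing in for Feedforward) of what the two added forward-pass lines compute; the final .detach() keeps gradients out of CoVe and the frozen pretrained embeddings, so only project_cove and the layers above it are trained:

    import torch
    from torch import nn

    batch, time, dimension = 2, 5, 200                       # args.dimension assumed to be 200
    context_embedded = torch.randn(batch, time, 400)          # unprojected CharNGram(100) + GloVe(300)
    cove_vectors = torch.randn(batch, time, 600)              # stand-in for self.cove(context_embedded[:, :, -300:], lengths)
    project_cove = nn.Linear(1000, dimension)                 # stand-in for Feedforward(1000, args.dimension)

    combined = torch.cat([cove_vectors, context_embedded], -1).detach()  # (batch, time, 1000)
    context_embedded = project_cove(combined)                 # (batch, time, 200), input to the BiLSTM below
    print(context_embedded.size())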

@@ -83,11 +83,12 @@ def prepare_data(args, field, logger):
         vocab_sets.extend(split)
     if args.load is None:
-        logger.info(f'Building vocabulary')
+        logger.info(f'Getting pretrained word vectors')
         char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
         glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings)
         vectors = [char_vectors, glove_vectors]
         vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
+        logger.info(f'Building vocabulary')
         FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
     FIELD.decoder_itos = FIELD.vocab.itos[:args.max_generative_vocab]
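
For orientation, a rough standalone sketch of the vector loading above (the cache path is illustrative, and both downloads are large): torchtext concatenates every vector set handed to build_vocab, so each vocabulary entry ends up with 100 CharNGram dimensions followed by 300 GloVe dimensions, the 400-dim pretrained embeddings that Embedding and the CoVe slice [:, :, -300:] rely on.

    import torchtext

    char_vectors = torchtext.vocab.CharNGram(cache='.embeddings')   # 100-dim character n-gram vectors
    glove_vectors = torchtext.vocab.GloVe(cache='.embeddings')      # 300-dim GloVe 840B vectors
    print(char_vectors.dim, glove_vectors.dim)                      # expected: 100 300

    # vectors=[char_vectors, glove_vectors] in build_vocab concatenates per token,
    # which is why GloVe occupies the last 300 dimensions sliced out for CoVe.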