diff --git a/arguments.py b/arguments.py
index efea3d00..fc70fd47 100644
--- a/arguments.py
+++ b/arguments.py
@@ -62,6 +62,7 @@ def parse():
     parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
     parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
     parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
+    parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
     parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
     parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
 
diff --git a/dockerfiles/Dockerfile b/dockerfiles/Dockerfile
index 0864b498..12731daf 100644
--- a/dockerfiles/Dockerfile
+++ b/dockerfiles/Dockerfile
@@ -61,6 +61,7 @@ RUN apt-get install --yes \
     python-lxml
 
 # WikiSQL evaluation
+RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove
 RUN pip install records
 RUN pip install babel
 RUN pip install tabulate
diff --git a/models/common.py b/models/common.py
index 0292385c..1a0345e6 100644
--- a/models/common.py
+++ b/models/common.py
@@ -320,23 +320,25 @@ class Feedforward(nn.Module):
 
 class Embedding(nn.Module):
 
-    def __init__(self, field, trained_dimension, dropout=0.0):
+    def __init__(self, field, trained_dimension, dropout=0.0, project=True):
         super().__init__()
         self.field = field
+        self.project = project
         dimension = 0
         pretrained_dimension = field.vocab.vectors.size(-1)
         self.pretrained_embeddings = [nn.Embedding(len(field.vocab), pretrained_dimension)]
         self.pretrained_embeddings[0].weight.data = field.vocab.vectors
         self.pretrained_embeddings[0].weight.requires_grad = False
         dimension += pretrained_dimension
-        self.projection = Feedforward(dimension, trained_dimension)
+        if self.project:
+            self.projection = Feedforward(dimension, trained_dimension)
         dimension = trained_dimension
         self.dropout = nn.Dropout(0.2)
         self.dimension = dimension
 
     def forward(self, x, lengths=None):
         pretrained_embeddings = self.pretrained_embeddings[0](x.cpu()).cuda().detach()
-        return self.projection(pretrained_embeddings)
+        return self.projection(pretrained_embeddings) if self.project else pretrained_embeddings
 
     def set_embeddings(self, w):
         self.pretrained_embeddings[0].weight.data = w
diff --git a/models/multitask_question_answering_network.py b/models/multitask_question_answering_network.py
index 765a376d..b7c74721 100644
--- a/models/multitask_question_answering_network.py
+++ b/models/multitask_question_answering_network.py
@@ -7,6 +7,8 @@ from torch import nn
 from torch.nn import functional as F
 from torch.autograd import Variable
 
+from cove import MTLSTM
+
 from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask
 
 
@@ -19,11 +21,15 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         self.pad_idx = self.field.vocab.stoi[self.field.pad_token]
 
         self.encoder_embeddings = Embedding(field, args.dimension,
-            dropout=args.dropout_ratio)
+            dropout=args.dropout_ratio, project=not args.cove)
         self.decoder_embeddings = Embedding(field, args.dimension,
-            dropout=args.dropout_ratio)
+            dropout=args.dropout_ratio, project=True)
+
+        if self.args.cove:
+            self.cove = MTLSTM(model_cache=args.embeddings)
+            self.project_cove = Feedforward(1000, args.dimension)
 
-        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
+        self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension,
             batch_first=True, dropout=args.dropout_ratio, bidirectional=True, num_layers=1)
         self.coattention = CoattentiveLayer(args.dimension, dropout=0.3)
         dim = 2*args.dimension + args.dimension + args.dimension
@@ -69,6 +75,9 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         context_embedded = self.encoder_embeddings(context)
         question_embedded = self.encoder_embeddings(question)
 
+        if self.args.cove:
+            context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, -300:], context_lengths), context_embedded], -1).detach())
+            question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, -300:], question_lengths), question_embedded], -1).detach())
         context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
         question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
 
diff --git a/train.py b/train.py
index 331a05b3..0d64b957 100644
--- a/train.py
+++ b/train.py
@@ -83,11 +83,12 @@ def prepare_data(args, field, logger):
             vocab_sets.extend(split)
 
     if args.load is None:
-        logger.info(f'Building vocabulary')
+        logger.info(f'Getting pretrained word vectors')
         char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings)
         glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings)
         vectors = [char_vectors, glove_vectors]
         vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
+    logger.info(f'Building vocabulary')
     FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
 
     FIELD.decoder_itos = FIELD.vocab.itos[:args.max_generative_vocab]
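
For reviewers, a minimal sketch (kept outside the patch) of why project_cove is sized Feedforward(1000, args.dimension): assuming the pretrained vectors are CharNGram (100-d) plus GloVe (300-d), the unprojected Embedding output is 400-d per token, and assuming the cove MTLSTM returns 600-d CoVe vectors for the trailing 300-d GloVe slice, the concatenation fed to project_cove is 1000-d. The nn.Linear and random tensors below are stand-ins for the repo's classes and real embeddings, not part of the change:

# Illustrative dimension check for the CoVe path added above; not part of the patch.
# Assumptions: CharNGram (100-d) + GloVe (300-d) pretrained vectors, 600-d CoVe output,
# nn.Linear standing in for Feedforward, random tensors standing in for embeddings.
import torch
from torch import nn

batch, seq_len, model_dim = 2, 5, 200            # model_dim plays the role of args.dimension

pretrained = torch.randn(batch, seq_len, 400)    # stand-in for encoder_embeddings(context), unprojected
cove_out = torch.randn(batch, seq_len, 600)      # stand-in for self.cove(pretrained[:, :, -300:], lengths)

project_cove = nn.Linear(1000, model_dim)        # stand-in for Feedforward(1000, args.dimension)
context_embedded = project_cove(torch.cat([cove_out, pretrained], -1).detach())
assert context_embedded.shape == (batch, seq_len, model_dim)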