more dynamic elmo use cases

This commit is contained in:
Bryan Marcus McCann 2018-11-30 00:19:13 +00:00
parent 37003d2a0a
commit 9a4493b655
2 changed files with 29 additions and 18 deletions

View File

@ -65,7 +65,8 @@ def parse():
parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy') parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)') parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)') parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)')
parser.add_argument('--elmo', action='store_true', help='whether to use deep contextualized word vectors (Peters et al. 2018)') parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none ')
parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate') parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping') parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
@ -119,7 +120,7 @@ def parse():
f'{args.world_size}g', f'{args.world_size}g',
args.commit[:7]) args.commit[:7])
args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file') args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file')
save_args(args) save_args(args)
return args return args

View File

@ -24,16 +24,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
def dp(args): def dp(args):
return args.dropout_ratio if args.rnn_layers > 1 else 0. return args.dropout_ratio if args.rnn_layers > 1 else 0.
if self.args.elmo: if self.args.glove_and_char:
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False)
elmo_params = get_trainable_params(self.elmo)
for p in elmo_params:
p.requires_grad = False
elmo_dim = 1024
self.project_elmo = Feedforward(elmo_dim, args.dimension)
else:
self.encoder_embeddings = Embedding(field, args.dimension, self.encoder_embeddings = Embedding(field, args.dimension,
dropout=args.dropout_ratio, project=not args.cove) dropout=args.dropout_ratio, project=not args.cove)
@ -45,6 +36,18 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
p.requires_grad = False p.requires_grad = False
cove_dim = int(args.intermediate_cove) * 600 + int(args.cove) * 600 + 400 # the last 400 is for GloVe and char n-gram embeddings cove_dim = int(args.intermediate_cove) * 600 + int(args.cove) * 600 + 400 # the last 400 is for GloVe and char n-gram embeddings
self.project_cove = Feedforward(cove_dim, args.dimension) self.project_cove = Feedforward(cove_dim, args.dimension)
if -1 not in self.args.elmo:
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False)
elmo_params = get_trainable_params(self.elmo)
for p in elmo_params:
p.requires_grad = False
elmo_dim = 1024 * len(self.args.elmo)
self.project_elmo = Feedforward(elmo_dim, args.dimension)
if self.args.glove_and_char:
self.project_embeddings = Feedforward(2 * args.dimension, args.dimension, dropout=0.0)
self.decoder_embeddings = Embedding(field, args.dimension, self.decoder_embeddings = Embedding(field, args.dimension,
dropout=args.dropout_ratio, project=True) dropout=args.dropout_ratio, project=True)
@ -93,17 +96,24 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
return limited_idx_to_full_idx[x] return limited_idx_to_full_idx[x]
self.map_to_full = map_to_full self.map_to_full = map_to_full
if self.args.elmo: if -1 not in self.args.elmo:
def elmo(z, device): def elmo(z, layers, device):
return self.elmo(batch_to_ids(z).to(device))['elmo_representations'][-1] e = self.elmo(batch_to_ids(z).to(device))['elmo_representations']
context_embedded = self.project_elmo(elmo(context_elmo, context.device).detach()) return torch.cat([e[x] for x in layers], -1)
question_embedded = self.project_elmo(elmo(question_elmo, question.device).detach()) context_elmo = self.project_elmo(elmo(context_elmo, self.args.elmo, context.device).detach())
else: question_elmo = self.project_elmo(elmo(question_elmo, self.args.elmo, question.device).detach())
if self.args.glove_and_char:
context_embedded = self.encoder_embeddings(context) context_embedded = self.encoder_embeddings(context)
question_embedded = self.encoder_embeddings(question) question_embedded = self.encoder_embeddings(question)
if self.args.cove: if self.args.cove:
context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, -300:], context_lengths), context_embedded], -1).detach()) context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, -300:], context_lengths), context_embedded], -1).detach())
question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, -300:], question_lengths), question_embedded], -1).detach()) question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, -300:], question_lengths), question_embedded], -1).detach())
if -1 not in self.args.elmo:
context_embedded = self.project_embeddings(torch.cat([context_embedded, context_elmo], -1))
question_embedded = self.project_embeddings(torch.cat([question_embedded, question_elmo], -1))
else:
context_embedded, question_embedded = context_elmo, question_elmo
context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0] context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0] question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]