diff --git a/arguments.py b/arguments.py index 608c1d9d..852aca4d 100644 --- a/arguments.py +++ b/arguments.py @@ -65,7 +65,8 @@ def parse(): parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy') parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)') parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)') - parser.add_argument('--elmo', action='store_true', help='whether to use deep contextualized word vectors (Peters et al. 2018)') + parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none ') + parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings') parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate') parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping') @@ -119,7 +120,7 @@ def parse(): f'{args.world_size}g', args.commit[:7]) args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file') - + save_args(args) return args diff --git a/models/multitask_question_answering_network.py b/models/multitask_question_answering_network.py index 1750624d..407a4ea0 100644 --- a/models/multitask_question_answering_network.py +++ b/models/multitask_question_answering_network.py @@ -24,16 +24,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module): def dp(args): return args.dropout_ratio if args.rnn_layers > 1 else 0. - if self.args.elmo: - options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json" - weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" - self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False) - elmo_params = get_trainable_params(self.elmo) - for p in elmo_params: - p.requires_grad = False - elmo_dim = 1024 - self.project_elmo = Feedforward(elmo_dim, args.dimension) - else: + if self.args.glove_and_char: self.encoder_embeddings = Embedding(field, args.dimension, dropout=args.dropout_ratio, project=not args.cove) @@ -45,6 +36,18 @@ class MultitaskQuestionAnsweringNetwork(nn.Module): p.requires_grad = False cove_dim = int(args.intermediate_cove) * 600 + int(args.cove) * 600 + 400 # the last 400 is for GloVe and char n-gram embeddings self.project_cove = Feedforward(cove_dim, args.dimension) + + if -1 not in self.args.elmo: + options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json" + weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" + self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False) + elmo_params = get_trainable_params(self.elmo) + for p in elmo_params: + p.requires_grad = False + elmo_dim = 1024 * len(self.args.elmo) + self.project_elmo = Feedforward(elmo_dim, args.dimension) + if self.args.glove_and_char: + self.project_embeddings = Feedforward(2 * args.dimension, args.dimension, dropout=0.0) self.decoder_embeddings = Embedding(field, args.dimension, dropout=args.dropout_ratio, project=True) @@ -93,17 +96,24 @@ class MultitaskQuestionAnsweringNetwork(nn.Module): return limited_idx_to_full_idx[x] self.map_to_full = map_to_full - if self.args.elmo: - def elmo(z, device): - return self.elmo(batch_to_ids(z).to(device))['elmo_representations'][-1] - context_embedded = self.project_elmo(elmo(context_elmo, context.device).detach()) - question_embedded = self.project_elmo(elmo(question_elmo, question.device).detach()) - else: + if -1 not in self.args.elmo: + def elmo(z, layers, device): + e = self.elmo(batch_to_ids(z).to(device))['elmo_representations'] + return torch.cat([e[x] for x in layers], -1) + context_elmo = self.project_elmo(elmo(context_elmo, self.args.elmo, context.device).detach()) + question_elmo = self.project_elmo(elmo(question_elmo, self.args.elmo, question.device).detach()) + + if self.args.glove_and_char: context_embedded = self.encoder_embeddings(context) question_embedded = self.encoder_embeddings(question) if self.args.cove: context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, -300:], context_lengths), context_embedded], -1).detach()) question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, -300:], question_lengths), question_embedded], -1).detach()) + if -1 not in self.args.elmo: + context_embedded = self.project_embeddings(torch.cat([context_embedded, context_elmo], -1)) + question_embedded = self.project_embeddings(torch.cat([question_embedded, question_elmo], -1)) + else: + context_embedded, question_embedded = context_elmo, question_elmo context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0] question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]