more dynamic elmo use cases
This commit is contained in:
parent
37003d2a0a
commit
9a4493b655
|
@ -65,7 +65,8 @@ def parse():
|
|||
parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
|
||||
parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
|
||||
parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)')
|
||||
parser.add_argument('--elmo', action='store_true', help='whether to use deep contextualized word vectors (Peters et al. 2018)')
|
||||
parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none ')
|
||||
parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
|
||||
|
||||
parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
|
||||
parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
|
||||
|
@ -119,7 +120,7 @@ def parse():
|
|||
f'{args.world_size}g',
|
||||
args.commit[:7])
|
||||
args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file')
|
||||
|
||||
|
||||
save_args(args)
|
||||
|
||||
return args
|
||||
|
|
|
@ -24,16 +24,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
def dp(args):
|
||||
return args.dropout_ratio if args.rnn_layers > 1 else 0.
|
||||
|
||||
if self.args.elmo:
|
||||
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
|
||||
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
|
||||
self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False)
|
||||
elmo_params = get_trainable_params(self.elmo)
|
||||
for p in elmo_params:
|
||||
p.requires_grad = False
|
||||
elmo_dim = 1024
|
||||
self.project_elmo = Feedforward(elmo_dim, args.dimension)
|
||||
else:
|
||||
if self.args.glove_and_char:
|
||||
|
||||
self.encoder_embeddings = Embedding(field, args.dimension,
|
||||
dropout=args.dropout_ratio, project=not args.cove)
|
||||
|
@ -45,6 +36,18 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
p.requires_grad = False
|
||||
cove_dim = int(args.intermediate_cove) * 600 + int(args.cove) * 600 + 400 # the last 400 is for GloVe and char n-gram embeddings
|
||||
self.project_cove = Feedforward(cove_dim, args.dimension)
|
||||
|
||||
if -1 not in self.args.elmo:
|
||||
options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
|
||||
weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
|
||||
self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False)
|
||||
elmo_params = get_trainable_params(self.elmo)
|
||||
for p in elmo_params:
|
||||
p.requires_grad = False
|
||||
elmo_dim = 1024 * len(self.args.elmo)
|
||||
self.project_elmo = Feedforward(elmo_dim, args.dimension)
|
||||
if self.args.glove_and_char:
|
||||
self.project_embeddings = Feedforward(2 * args.dimension, args.dimension, dropout=0.0)
|
||||
|
||||
self.decoder_embeddings = Embedding(field, args.dimension,
|
||||
dropout=args.dropout_ratio, project=True)
|
||||
|
@ -93,17 +96,24 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
return limited_idx_to_full_idx[x]
|
||||
self.map_to_full = map_to_full
|
||||
|
||||
if self.args.elmo:
|
||||
def elmo(z, device):
|
||||
return self.elmo(batch_to_ids(z).to(device))['elmo_representations'][-1]
|
||||
context_embedded = self.project_elmo(elmo(context_elmo, context.device).detach())
|
||||
question_embedded = self.project_elmo(elmo(question_elmo, question.device).detach())
|
||||
else:
|
||||
if -1 not in self.args.elmo:
|
||||
def elmo(z, layers, device):
|
||||
e = self.elmo(batch_to_ids(z).to(device))['elmo_representations']
|
||||
return torch.cat([e[x] for x in layers], -1)
|
||||
context_elmo = self.project_elmo(elmo(context_elmo, self.args.elmo, context.device).detach())
|
||||
question_elmo = self.project_elmo(elmo(question_elmo, self.args.elmo, question.device).detach())
|
||||
|
||||
if self.args.glove_and_char:
|
||||
context_embedded = self.encoder_embeddings(context)
|
||||
question_embedded = self.encoder_embeddings(question)
|
||||
if self.args.cove:
|
||||
context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, -300:], context_lengths), context_embedded], -1).detach())
|
||||
question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, -300:], question_lengths), question_embedded], -1).detach())
|
||||
if -1 not in self.args.elmo:
|
||||
context_embedded = self.project_embeddings(torch.cat([context_embedded, context_elmo], -1))
|
||||
question_embedded = self.project_embeddings(torch.cat([question_embedded, question_elmo], -1))
|
||||
else:
|
||||
context_embedded, question_embedded = context_elmo, question_elmo
|
||||
|
||||
context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
|
||||
question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
|
||||
|
|
Loading…
Reference in New Issue