more dynamic elmo use cases

parent: 37003d2a0a
commit: 9a4493b655
@@ -65,7 +65,8 @@ def parse():
     parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy')
     parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
     parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)')
-    parser.add_argument('--elmo', action='store_true', help='whether to use deep contextualized word vectors (Peters et al. 2018)')
+    parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none')
+    parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
 
     parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate')
     parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping')
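With nargs='+' and type=int, the reworked --elmo flag always yields a list of ints, and the default [-1] stands for "ELMo off"; the model code later in this commit therefore tests `-1 not in args.elmo` instead of truthiness. A standalone sketch of the parsing behavior (not part of the commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--elmo', default=[-1], nargs='+', type=int,
                        help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none')

    print(parser.parse_args([]).elmo)                       # [-1]      -> ELMo disabled
    print(parser.parse_args('--elmo 0'.split()).elmo)       # [0]       -> layer 0 only
    print(parser.parse_args('--elmo 0 1 2'.split()).elmo)   # [0, 1, 2] -> all three layers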
@@ -119,7 +120,7 @@ def parse():
                                 f'{args.world_size}g',
                                 args.commit[:7])
     args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file')
 
     save_args(args)
 
     return args
@@ -24,16 +24,7 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
         def dp(args):
             return args.dropout_ratio if args.rnn_layers > 1 else 0.
 
-        if self.args.elmo:
-            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
-            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
-            self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False)
-            elmo_params = get_trainable_params(self.elmo)
-            for p in elmo_params:
-                p.requires_grad = False
-            elmo_dim = 1024
-            self.project_elmo = Feedforward(elmo_dim, args.dimension)
-        else:
+        if self.args.glove_and_char:
 
             self.encoder_embeddings = Embedding(field, args.dimension,
                                                 dropout=args.dropout_ratio, project=not args.cove)
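The gate on this branch changes from `self.args.elmo` to `self.args.glove_and_char` because, with the new list-valued flag, truthiness no longer means "ELMo enabled": the default [-1] is a non-empty list and therefore always truthy. A two-line illustration:

    args_elmo = [-1]                # the new default: ELMo disabled
    print(bool(args_elmo))          # True  -- the old `if self.args.elmo:` test would still fire
    print(-1 not in args_elmo)      # False -- the commit's replacement test correctly reads "disabled"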
@@ -45,6 +36,18 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                 p.requires_grad = False
             cove_dim = int(args.intermediate_cove) * 600 + int(args.cove) * 600 + 400  # the last 400 is for GloVe and char n-gram embeddings
             self.project_cove = Feedforward(cove_dim, args.dimension)
 
+        if -1 not in self.args.elmo:
+            options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
+            weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"
+            self.elmo = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False)
+            elmo_params = get_trainable_params(self.elmo)
+            for p in elmo_params:
+                p.requires_grad = False
+            elmo_dim = 1024 * len(self.args.elmo)
+            self.project_elmo = Feedforward(elmo_dim, args.dimension)
+            if self.args.glove_and_char:
+                self.project_embeddings = Feedforward(2 * args.dimension, args.dimension, dropout=0.0)
+
         self.decoder_embeddings = Embedding(field, args.dimension,
                                             dropout=args.dropout_ratio, project=True)
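Each ELMo output layer is 1024-dimensional, so selecting k layers and concatenating them gives elmo_dim = 1024 * k; project_elmo maps that down to args.dimension, and when the GloVe/char path is also active the two projected streams are concatenated, which is why project_embeddings takes 2 * args.dimension as input. A shape sketch with dummy tensors, using nn.Linear as a stand-in for the repo's Feedforward:

    import torch
    import torch.nn as nn

    dimension = 200                          # hypothetical args.dimension
    layers = [0, 2]                          # args.elmo after parsing
    elmo_dim = 1024 * len(layers)            # 2048: two 1024-d layers concatenated

    batch, time = 4, 10
    elmo_out = torch.randn(batch, time, elmo_dim)
    project_elmo = nn.Linear(elmo_dim, dimension)
    elmo_proj = project_elmo(elmo_out)                    # (4, 10, 200)

    glove_proj = torch.randn(batch, time, dimension)      # projected GloVe + char n-gram path
    combined = torch.cat([glove_proj, elmo_proj], -1)     # (4, 10, 400) = 2 * dimension
    project_embeddings = nn.Linear(2 * dimension, dimension)
    print(project_embeddings(combined).shape)             # torch.Size([4, 10, 200])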
@@ -93,17 +96,24 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                 return limited_idx_to_full_idx[x]
         self.map_to_full = map_to_full
 
-        if self.args.elmo:
-            def elmo(z, device):
-                return self.elmo(batch_to_ids(z).to(device))['elmo_representations'][-1]
-            context_embedded = self.project_elmo(elmo(context_elmo, context.device).detach())
-            question_embedded = self.project_elmo(elmo(question_elmo, question.device).detach())
-        else:
+        if -1 not in self.args.elmo:
+            def elmo(z, layers, device):
+                e = self.elmo(batch_to_ids(z).to(device))['elmo_representations']
+                return torch.cat([e[x] for x in layers], -1)
+            context_elmo = self.project_elmo(elmo(context_elmo, self.args.elmo, context.device).detach())
+            question_elmo = self.project_elmo(elmo(question_elmo, self.args.elmo, question.device).detach())
+
+        if self.args.glove_and_char:
             context_embedded = self.encoder_embeddings(context)
             question_embedded = self.encoder_embeddings(question)
             if self.args.cove:
                 context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, -300:], context_lengths), context_embedded], -1).detach())
                 question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, -300:], question_lengths), question_embedded], -1).detach())
+            if -1 not in self.args.elmo:
+                context_embedded = self.project_embeddings(torch.cat([context_embedded, context_elmo], -1))
+                question_embedded = self.project_embeddings(torch.cat([question_embedded, question_elmo], -1))
+        else:
+            context_embedded, question_embedded = context_elmo, question_elmo
 
         context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
         question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
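The rewritten closure asks the AllenNLP Elmo module (constructed above with num_output_representations=3) for all three representations and concatenates the requested ones along the feature axis. A self-contained sketch of the same pattern; it assumes allennlp is installed and downloads the pretrained weights on first use:

    import torch
    from allennlp.modules.elmo import Elmo, batch_to_ids

    options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json"
    weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5"

    elmo_module = Elmo(options_file, weight_file, 3, dropout=0.0, do_layer_norm=False)

    def elmo(z, layers, device):
        # 'elmo_representations' is a list of 3 tensors, each (batch, time, 1024)
        e = elmo_module(batch_to_ids(z).to(device))['elmo_representations']
        return torch.cat([e[x] for x in layers], -1)

    sentences = [['Who', 'wrote', 'it', '?']]
    out = elmo(sentences, [0, 2], torch.device('cpu'))
    print(out.shape)   # torch.Size([1, 4, 2048]): layers 0 and 2 concatenated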