parent
1b891f38c1
commit
c4a9c49d48
|
@ -126,7 +126,6 @@ def parse(argv):
|
|||
|
||||
parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use exisiting cached splits or generate new ones')
|
||||
parser.add_argument('--lr_rate', default=0.001, type=float, help='initial_learning_rate')
|
||||
parser.add_argument('--use_maxmargin_loss', action='store_true', help='whether to use max-margin loss or not')
|
||||
parser.add_argument('--small_glove', action='store_true', help='Use glove.6B.50d instead of glove.840B.300d')
|
||||
parser.add_argument('--almond_type_embeddings', action='store_true', help='Add type-based word embeddings for Almond task')
|
||||
parser.add_argument('--use_curriculum', action='store_true', help='Use curriculum learning')
|
||||
|
|
|
@ -72,25 +72,6 @@ class LSTMDecoder(nn.Module):
|
|||
|
||||
return input, (h_1, c_1)
|
||||
|
||||
def max_margin_loss(probs, targets, pad_idx=1):
    """Return a summed max-margin loss over all non-padding target positions.

    For each position the loss is
    ``max_k(score_k + [k != gold]) - score_gold``: every non-gold class gets
    a margin of 1 added to its score, and the gold score is subtracted from
    the highest augmented score. Positions whose target equals ``pad_idx``
    contribute nothing.

    Args:
        probs: 3-D tensor of per-class scores, unpacked as
            ``(batch_size, max_length, depth)``. The name suggests
            probabilities, but nothing here requires them to be normalized.
        targets: integer tensor of gold class indices with
            ``batch_size * max_length`` elements (one per position), used as
            scatter/gather indices along the class dimension.
        pad_idx: target value marking padding positions to be ignored.

    Returns:
        A scalar tensor: the sum (not the mean) of the per-position margins
        over non-padding positions, plus a constant ``1e-8``.
    """
    batch_size, max_length, depth = probs.size()
    num_positions = batch_size * max_length

    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. a transposed tensor), reshape() copies only when it has to.
    flat_preds = probs.reshape(num_positions, depth)
    flat_targets = targets.reshape(num_positions, 1)
    flat_mask = (flat_targets.squeeze(1) != pad_idx).float()

    # One-hot encoding of the gold class for every position.
    one_hot_gold = torch.zeros_like(flat_preds).scatter_(1, flat_targets, 1)

    # Add a margin of 1 to every class except the gold one, then take the
    # best augmented score per position.
    marginal_scores = flat_preds - one_hot_gold + 1
    max_margin = torch.max(marginal_scores, dim=1)[0]

    # Pick the gold score by index. This replaces masked_select with a
    # uint8 (.byte()) mask, which is deprecated in current PyTorch.
    gold_score = flat_preds.gather(1, flat_targets).squeeze(1)
    margin = max_margin - gold_score

    return torch.sum(margin * flat_mask) + 1e-8
|
||||
|
||||
|
||||
def positional_encodings_like(x, t=None):
|
||||
if t is None:
|
||||
|
|
|
@ -192,13 +192,8 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
|
|||
decoder_vocab)
|
||||
|
||||
|
||||
if self.args.use_maxmargin_loss:
|
||||
targets = answer_indices[:, 1:].contiguous()
|
||||
loss = max_margin_loss(probs, targets, pad_idx=pad_idx)
|
||||
|
||||
else:
|
||||
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
|
||||
loss = F.nll_loss(probs.log(), targets)
|
||||
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
|
||||
loss = F.nll_loss(probs.log(), targets)
|
||||
return loss, None
|
||||
|
||||
else:
|
||||
|
|
|
@ -196,7 +196,7 @@ def load_config_json(args):
|
|||
retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden', 'dimension',
|
||||
'load', 'max_val_context_length', 'val_batch_size', 'transformer_heads', 'max_output_length',
|
||||
'max_generative_vocab', 'lower', 'glove_and_char',
|
||||
'use_maxmargin_loss', 'small_glove', 'almond_type_embeddings', 'almond_grammar',
|
||||
'small_glove', 'almond_type_embeddings', 'almond_grammar',
|
||||
'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
|
||||
'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']
|
||||
|
||||
|
@ -205,7 +205,7 @@ def load_config_json(args):
|
|||
setattr(args, r, config[r])
|
||||
elif r == 'locale':
|
||||
setattr(args, r, 'en')
|
||||
elif r in ('use_maxmargin_loss', 'small_glove', 'almond_type_embbedings'):
|
||||
elif r in ('small_glove', 'almond_type_embbedings'):
|
||||
setattr(args, r, False)
|
||||
elif r in ('glove_decoder', 'glove_and_char'):
|
||||
setattr(args, r, True)
|
||||
|
|
Loading…
Reference in New Issue