Remove max-margin loss

It doesn't work
Giovanni Campagna 2020-01-14 09:58:12 -08:00
parent 1b891f38c1
commit c4a9c49d48
4 changed files with 4 additions and 29 deletions


@@ -126,7 +126,6 @@ def parse(argv):
     parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to use existing cached splits or generate new ones')
     parser.add_argument('--lr_rate', default=0.001, type=float, help='initial_learning_rate')
-    parser.add_argument('--use_maxmargin_loss', action='store_true', help='whether to use max-margin loss or not')
     parser.add_argument('--small_glove', action='store_true', help='Use glove.6B.50d instead of glove.840B.300d')
     parser.add_argument('--almond_type_embeddings', action='store_true', help='Add type-based word embeddings for Almond task')
     parser.add_argument('--use_curriculum', action='store_true', help='Use curriculum learning')
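For illustration only, a minimal sketch (not from the repository) of what this hunk means for existing training commands: argparse now rejects the removed flag, so any invocation that still passes --use_maxmargin_loss has to drop it. Only a couple of the surviving options are reproduced here.

import argparse

# Minimal sketch: mirror just the options needed for the example.
parser = argparse.ArgumentParser()
parser.add_argument('--lr_rate', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--small_glove', action='store_true', help='Use glove.6B.50d instead of glove.840B.300d')

args = parser.parse_args(['--lr_rate', '0.0005', '--small_glove'])
print(args.lr_rate, args.small_glove)  # 0.0005 True

# After this commit, the removed flag fails at parse time:
# parser.parse_args(['--use_maxmargin_loss'])  -> SystemExit: unrecognized arguments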


@@ -72,25 +72,6 @@ class LSTMDecoder(nn.Module):
         return input, (h_1, c_1)
-def max_margin_loss(probs, targets, pad_idx=1):
-    batch_size, max_length, depth = probs.size()
-    targets_mask = (targets != pad_idx).float()
-    flat_mask = targets_mask.view(batch_size*max_length,)
-    flat_preds = probs.view(batch_size*max_length, depth)
-    one_hot = torch.zeros_like(probs)
-    one_hot_gold = one_hot.scatter_(2, targets.unsqueeze(2), 1)
-    marginal_scores = probs - one_hot_gold + 1
-    marginal_scores = marginal_scores.view(batch_size*max_length, depth)
-    max_margin = torch.max(marginal_scores, dim=1)[0]
-    gold_score = torch.masked_select(flat_preds, one_hot_gold.view(batch_size*max_length, depth).byte())
-    margin = max_margin - gold_score
-    return torch.sum(margin*flat_mask) + 1e-8
 def positional_encodings_like(x, t=None):
     if t is None:
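For reference, a self-contained sketch of the sequence-level multiclass hinge ("max-margin") objective the deleted function implemented. Tensor names mirror the removed code, but the snippet is illustrative; it substitutes .bool() for the deprecated .byte().

import torch

def max_margin_loss_sketch(probs, targets, pad_idx=1):
    # probs: (batch, length, vocab) scores; targets: (batch, length) gold token indices
    batch_size, max_length, depth = probs.size()
    flat_mask = (targets != pad_idx).float().view(batch_size * max_length)
    flat_preds = probs.view(batch_size * max_length, depth)
    # one-hot encode the gold tokens
    one_hot_gold = torch.zeros_like(probs).scatter_(2, targets.unsqueeze(2), 1)
    # add a margin of 1 to every non-gold class, then take the strongest competitor
    marginal_scores = (probs - one_hot_gold + 1).view(batch_size * max_length, depth)
    max_margin = torch.max(marginal_scores, dim=1)[0]
    # score assigned to the gold token at each position
    gold_score = torch.masked_select(flat_preds, one_hot_gold.view(batch_size * max_length, depth).bool())
    # hinge: penalize whenever a competitor comes within the margin of the gold score
    return torch.sum((max_margin - gold_score) * flat_mask) + 1e-8

# toy usage
probs = torch.randn(2, 5, 10)
targets = torch.randint(0, 10, (2, 5))
print(max_margin_loss_sketch(probs, targets))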


@@ -192,13 +192,8 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                 decoder_vocab)
-            if self.args.use_maxmargin_loss:
-                targets = answer_indices[:, 1:].contiguous()
-                loss = max_margin_loss(probs, targets, pad_idx=pad_idx)
-            else:
-                probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
-                loss = F.nll_loss(probs.log(), targets)
+            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
+            loss = F.nll_loss(probs.log(), targets)
             return loss, None
         else:
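A minimal sketch of the loss path that remains after this change: pad positions are masked out and a standard token-level negative log-likelihood is taken over the predicted distributions. The mask_sketch helper below is an assumed, simplified stand-in for the repository's own mask() function.

import torch
import torch.nn.functional as F

def mask_sketch(targets, probs, pad_idx=1):
    # Simplified stand-in: drop padding positions, flatten the rest.
    keep = targets.reshape(-1) != pad_idx
    return probs.reshape(-1, probs.size(-1))[keep], targets.reshape(-1)[keep]

# probs: per-token distributions (already softmax-normalized); targets are the shifted gold indices
probs = torch.softmax(torch.randn(2, 5, 10), dim=-1)
answer_indices = torch.randint(2, 10, (2, 6))

flat_probs, flat_targets = mask_sketch(answer_indices[:, 1:].contiguous(), probs.contiguous())
loss = F.nll_loss(flat_probs.log(), flat_targets)
print(loss)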


@@ -196,7 +196,7 @@ def load_config_json(args):
     retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden', 'dimension',
                 'load', 'max_val_context_length', 'val_batch_size', 'transformer_heads', 'max_output_length',
                 'max_generative_vocab', 'lower', 'glove_and_char',
-                'use_maxmargin_loss', 'small_glove', 'almond_type_embeddings', 'almond_grammar',
+                'small_glove', 'almond_type_embeddings', 'almond_grammar',
                 'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
                 'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']
@@ -205,7 +205,7 @@ def load_config_json(args):
            setattr(args, r, config[r])
         elif r == 'locale':
             setattr(args, r, 'en')
-        elif r in ('use_maxmargin_loss', 'small_glove', 'almond_type_embbedings'):
+        elif r in ('small_glove', 'almond_type_embbedings'):
             setattr(args, r, False)
         elif r in ('glove_decoder', 'glove_and_char'):
             setattr(args, r, True)
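A minimal sketch of the backward-compatibility pattern this hunk adjusts: options missing from an older saved config.json are filled in with defaults rather than raising. The config contents and the Args stand-in below are assumptions for illustration; the option names mirror the hunk above.

# Hypothetical saved config that predates some newer options.
config = {'model': 'MQAN'}

class Args:
    """Stand-in for the argparse namespace that load_config_json populates."""
    pass

args = Args()
retrieve = ['model', 'locale', 'small_glove', 'glove_and_char']
for r in retrieve:
    if r in config:
        setattr(args, r, config[r])
    elif r == 'locale':
        setattr(args, r, 'en')            # default locale
    elif r in ('small_glove', 'almond_type_embbedings'):
        setattr(args, r, False)           # boolean options that default to off
    elif r in ('glove_decoder', 'glove_and_char'):
        setattr(args, r, True)            # boolean options that default to on

print(vars(args))  # {'model': 'MQAN', 'locale': 'en', 'small_glove': False, 'glove_and_char': True}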