From c4a9c49d48c4ae710380d721ab619836000e9c0d Mon Sep 17 00:00:00 2001
From: Giovanni Campagna
Date: Tue, 14 Jan 2020 09:58:12 -0800
Subject: [PATCH] Remove max-margin loss

It doesn't work
---
 decanlp/arguments.py                          |  1 -
 decanlp/models/common.py                      | 19 -------------------
 .../multitask_question_answering_network.py   |  9 ++-------
 decanlp/util.py                               |  4 ++--
 4 files changed, 4 insertions(+), 29 deletions(-)

diff --git a/decanlp/arguments.py b/decanlp/arguments.py
index bd8c14ce..6506257b 100644
--- a/decanlp/arguments.py
+++ b/decanlp/arguments.py
@@ -126,7 +126,6 @@ def parse(argv):
     parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool',
                         help='whether to use exisiting cached splits or generate new ones')
     parser.add_argument('--lr_rate', default=0.001, type=float, help='initial_learning_rate')
-    parser.add_argument('--use_maxmargin_loss', action='store_true', help='whether to use max-margin loss or not')
     parser.add_argument('--small_glove', action='store_true', help='Use glove.6B.50d instead of glove.840B.300d')
     parser.add_argument('--almond_type_embeddings', action='store_true', help='Add type-based word embeddings for Almond task')
     parser.add_argument('--use_curriculum', action='store_true', help='Use curriculum learning')
diff --git a/decanlp/models/common.py b/decanlp/models/common.py
index c4a4dda4..3d9b9042 100644
--- a/decanlp/models/common.py
+++ b/decanlp/models/common.py
@@ -72,25 +72,6 @@ class LSTMDecoder(nn.Module):
         return input, (h_1, c_1)
 
 
-def max_margin_loss(probs, targets, pad_idx=1):
-
-    batch_size, max_length, depth = probs.size()
-    targets_mask = (targets != pad_idx).float()
-    flat_mask = targets_mask.view(batch_size*max_length,)
-    flat_preds = probs.view(batch_size*max_length, depth)
-
-    one_hot = torch.zeros_like(probs)
-    one_hot_gold = one_hot.scatter_(2, targets.unsqueeze(2), 1)
-
-    marginal_scores = probs - one_hot_gold + 1
-    marginal_scores = marginal_scores.view(batch_size*max_length, depth)
-    max_margin = torch.max(marginal_scores, dim=1)[0]
-
-    gold_score = torch.masked_select(flat_preds, one_hot_gold.view(batch_size*max_length, depth).byte())
-    margin = max_margin - gold_score
-
-    return torch.sum(margin*flat_mask) + 1e-8
-
 
 def positional_encodings_like(x, t=None):
     if t is None:
diff --git a/decanlp/models/multitask_question_answering_network.py b/decanlp/models/multitask_question_answering_network.py
index a2b5583f..ebabbb56 100644
--- a/decanlp/models/multitask_question_answering_network.py
+++ b/decanlp/models/multitask_question_answering_network.py
@@ -192,13 +192,8 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
                                           decoder_vocab)
 
 
-            if self.args.use_maxmargin_loss:
-                targets = answer_indices[:, 1:].contiguous()
-                loss = max_margin_loss(probs, targets, pad_idx=pad_idx)
-
-            else:
-                probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
-                loss = F.nll_loss(probs.log(), targets)
+            probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx)
+            loss = F.nll_loss(probs.log(), targets)
 
             return loss, None
         else:
diff --git a/decanlp/util.py b/decanlp/util.py
index b7ac28e6..b9ef6afc 100644
--- a/decanlp/util.py
+++ b/decanlp/util.py
@@ -196,7 +196,7 @@ def load_config_json(args):
     retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden',
                 'dimension', 'load', 'max_val_context_length', 'val_batch_size',
                 'transformer_heads', 'max_output_length', 'max_generative_vocab', 'lower', 'glove_and_char',
-                'use_maxmargin_loss', 'small_glove', 'almond_type_embeddings', 'almond_grammar',
+                'small_glove', 'almond_type_embeddings', 'almond_grammar',
                 'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
                 'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']

@@ -205,7 +205,7 @@ def load_config_json(args):
             setattr(args, r, config[r])
         elif r == 'locale':
             setattr(args, r, 'en')
-        elif r in ('use_maxmargin_loss', 'small_glove', 'almond_type_embbedings'):
+        elif r in ('small_glove', 'almond_type_embbedings'):
            setattr(args, r, False)
        elif r in ('glove_decoder', 'glove_and_char'):
            setattr(args, r, True)
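
Note: with this patch applied, the decoder loss is always the masked
token-level negative log-likelihood kept in the `+` lines above. The
following is a minimal standalone sketch of that code path, assuming a
simplified mask() helper and random tensors in place of the network's
outputs; decanlp's real helper may differ in detail, so treat this as an
illustration rather than the project's actual implementation.

import torch
import torch.nn.functional as F

def mask(targets, probs, pad_idx=1):
    # Flatten (batch, time, vocab) probabilities and (batch, time) targets,
    # then drop every position whose gold token is padding.
    keep = (targets != pad_idx).view(-1)
    flat_probs = probs.view(-1, probs.size(-1))
    flat_targets = targets.view(-1)
    return flat_probs[keep], flat_targets[keep]

# Hypothetical stand-ins for the decoder's output distribution and the
# gold answer token indices.
batch, time, vocab, pad_idx = 2, 5, 11, 1
probs = torch.softmax(torch.randn(batch, time, vocab), dim=-1)
answer_indices = torch.randint(0, vocab, (batch, time + 1))

# Same shift as the patched code: the gold sequence is offset by one so
# each decoding step is scored against the next token.
probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(),
                      pad_idx=pad_idx)
loss = F.nll_loss(probs.log(), targets)
print(loss.item())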