Remove cove

Cove is an obsolete form of pretraining, of dubious utility. It
depends on an obsolete, unmaintained library, and it causes
hacks in the code. Clean up.
This commit is contained in:
Giovanni Campagna 2019-12-10 16:34:35 -08:00
parent 6911a5092d
commit e7481840f7
7 changed files with 4 additions and 33 deletions

View File

@ -18,7 +18,6 @@ sacrebleu = "*"
orderedset = "*"
records = "*"
tabulate = "*"
cove = {editable = true,git = "git://github.com/salesforce/cove.git"}
allennlp = "*"
tensorboardX = "*"
Babel = "*"

View File

@ -100,8 +100,6 @@ def parse(argv):
parser.add_argument('--transformer_hidden', default=150, type=int, help='hidden size of the transformer modules')
parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules')
parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model')
parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)')
parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)')
parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none ')
parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings')
parser.add_argument('--locale', default='en', help='locale to use for word embeddings')

View File

@ -59,19 +59,10 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
self.encoder_embeddings = Embedding(field, args.dimension,
trained_dimension=0,
dropout=args.dropout_ratio, project=not args.cove,
dropout=args.dropout_ratio,
project=True,
requires_grad=args.retrain_encoder_embedding)
if self.args.cove or self.args.intermediate_cove:
from cove import MTLSTM
self.cove = MTLSTM(model_cache=args.embeddings, layer0=args.intermediate_cove, layer1=args.cove)
cove_params = get_trainable_params(self.cove)
for p in cove_params:
p.requires_grad = False
cove_dim = int(args.intermediate_cove) * 600 + int(args.cove) * 600 + 400 + int(args.almond_type_embeddings) * 18 # the last 400 is for GloVe and char n-gram embeddings
self.project_cove = Feedforward(cove_dim, args.dimension)
if -1 not in self.args.elmo:
from allennlp.modules.elmo import Elmo, batch_to_ids
@ -176,9 +167,6 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
if self.args.glove_and_char:
context_embedded = self.encoder_embeddings(context)
question_embedded = self.encoder_embeddings(question)
if self.args.cove:
context_embedded = self.project_cove(torch.cat([self.cove(context_embedded[:, :, 100:400], context_lengths), context_embedded], -1).detach())
question_embedded = self.project_cove(torch.cat([self.cove(question_embedded[:, :, 100:400], question_lengths), question_embedded], -1).detach())
if -1 not in self.args.elmo:
context_embedded = self.project_embeddings(torch.cat([context_embedded, context_elmo], -1))
question_embedded = self.project_embeddings(torch.cat([question_embedded, question_elmo], -1))

View File

@ -296,12 +296,6 @@ def main(argv=sys.argv):
Model = getattr(models, args.model)
model = Model(field, args)
model_dict = save_dict['model_state_dict']
backwards_compatible_cove_dict = {}
for k, v in model_dict.items():
if 'cove.rnn.' in k:
k = k.replace('cove.rnn.', 'cove.rnn1.')
backwards_compatible_cove_dict[k] = v
model_dict = backwards_compatible_cove_dict
model.load_state_dict(model_dict)
field, splits = prepare_data(args, field)
if args.model != 'MultiLingualTranslationModel':

View File

@ -237,12 +237,6 @@ def main(argv=sys.argv):
Model = getattr(models, args.model)
model = Model(field, args)
model_dict = save_dict['model_state_dict']
backwards_compatible_cove_dict = {}
for k, v in model_dict.items():
if 'cove.rnn.' in k:
k = k.replace('cove.rnn.', 'cove.rnn1.')
backwards_compatible_cove_dict[k] = v
model_dict = backwards_compatible_cove_dict
model.load_state_dict(model_dict)
server = Server(args, field, model)

View File

@ -170,7 +170,7 @@ def load_config_json(args):
config = json.load(config_file)
retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden', 'world_size', 'dimension',
'load', 'max_val_context_length', 'val_batch_size', 'transformer_heads', 'max_output_length',
'max_generative_vocab', 'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char',
'max_generative_vocab', 'lower', 'elmo', 'glove_and_char',
'use_maxmargin_loss', 'small_glove', 'almond_type_embeddings', 'almond_grammar',
'trainable_decoder_embedding', 'glove_decoder', 'pretrained_decoder_lm',
'retrain_encoder_embedding', 'question', 'locale', 'use_google_translate']
@ -180,8 +180,7 @@ def load_config_json(args):
setattr(args, r, config[r])
elif r == 'locale':
setattr(args, r, 'en')
elif r in ('cove', 'intermediate_cove', 'use_maxmargin_loss', 'small_glove',
'almond_type_embbedings'):
elif r in ('use_maxmargin_loss', 'small_glove', 'almond_type_embbedings'):
setattr(args, r, False)
elif 'elmo' in r:
setattr(args, r, [-1])

View File

@ -19,5 +19,4 @@ sacrebleu
orderedset
requests
-e git+git://github.com/salesforce/cove.git#egg=cove
allennlp