updates and fixes

This commit is contained in:
mehrad 2019-01-23 16:41:37 -08:00
parent fdfdd154c4
commit 8c54dc1391
6 changed files with 9 additions and 38 deletions

View File

@ -25,7 +25,7 @@ def parse():
Returns the arguments from the command line.
"""
parser = ArgumentParser()
parser.add_argument('--root', default='/decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
parser.add_argument('--root', default='./decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
parser.add_argument('--data', default='.data/', type=str, help='where to load data from.')
parser.add_argument('--save', default='results', type=str, help='where to save results.')
parser.add_argument('--embeddings', default='.embeddings', type=str, help='where to save embeddings.')

View File

@ -11,7 +11,7 @@ from util import get_trainable_params, set_seed
from modules import expectedBLEU, expectedMultiBleu, matrixBLEU
from cove import MTLSTM
#from allennlp.modules.elmo import Elmo, batch_to_ids
from allennlp.modules.elmo import Elmo, batch_to_ids
from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask, CoattentiveLayer
@ -119,29 +119,6 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
else:
context_embedded, question_embedded = context_elmo, question_elmo
elif self.args.elmo:
context_list = []
for b in range(context.size(0)):
#FIXME
lc = [self.field.decoder_itos[i] if i < self.args.max_generative_vocab else self.field.decoder_itos[0] for i in context[b, :]]
if lc[0] == self.field.decoder_itos[2]:
lc[0] = '<S>'
if lc[-1] == self.field.decoder_itos[3]:
lc[-1] = '</S>'
context_list.append(lc)
question_list = []
for b in range(question.size(0)):
lq = [self.field.decoder_itos[i] if i < self.args.max_generative_vocab else self.field.decoder_itos[0] for i in question[b, :]]
if lq[0] == self.field.decoder_itos[2]:
lq[0] = '<S>'
if lq[-1] == self.field.decoder_itos[3]:
lq[-1] = '</S>'
question_list.append(lq)
context_embedded = self.project_elmo(torch.cat([self.elmo(batch_to_ids(context_list).to(self.device))['elmo_representations'][0], context_embedded], -1).detach())
question_embedded = self.project_elmo(torch.cat([self.elmo(batch_to_ids(question_list).to(self.device))['elmo_representations'][0], question_embedded], -1).detach())
context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]

View File

@ -15,7 +15,6 @@ import models
from text.torchtext.data.utils import get_tokenizer
def get_all_splits(args, new_vocab):
splits = []
for task in args.tasks:
@ -171,11 +170,7 @@ def run(args, field, val_sets, model):
a = from_all_answers(batch.woz_id.data.cpu())
else:
if task == 'almond':
setattr(field, 'use_revtok', False)
setattr(field, 'tokenize', tokenizer)
a = field.reverse_almond(batch.answer.data)
setattr(field, 'use_revtok', True)
setattr(field, 'tokenize', 'revtok')
a = field.reverse(batch.answer.data, detokenize=lambda x: ' '.join(x))
else:
a = field.reverse(batch.answer.data)
for aa in a:
@ -322,7 +317,6 @@ if __name__ == '__main__':
print(f'Loading from {args.best_checkpoint}')
# save_dict = torch.load(args.best_checkpoint)
if torch.cuda.is_available():
save_dict = torch.load(args.best_checkpoint)
else:

View File

@ -84,7 +84,7 @@ def prepare_data(args, field, logger):
for task, s in zip(args.train_tasks, train_sets):
for ex in s.examples[:10]:
print('examples***:', ex.context)
print('examples***:', [token.strip() for token in ex.context])
if args.load is None:
logger.info(f'Getting pretrained word vectors')

View File

@ -64,10 +64,10 @@ def preprocess_examples(args, tasks, splits, field, logger=None, train=True):
if logger is not None:
logger.info('Tokenized examples:')
for ex in s.examples[:10]:
logger.info('Context: ' + ' '.join(ex.context))
logger.info('Question: ' + ' '.join(ex.question))
logger.info(' '.join(ex.context_question))
logger.info('Answer: ' + ' '.join(ex.answer))
logger.info('Context: ' + ' '.join([token.strip() for token in ex.context]))
logger.info('Question: ' + ' '.join([token.strip() for token in ex.question]))
logger.info(' '.join([token.strip() for token in ex.context_question]))
logger.info('Answer: ' + ' '.join([token.strip() for token in ex.answer]))

View File

@ -44,7 +44,7 @@ def all_reverse(tensor, world_size, task, field, clip, dim=0):
# for distributed training, dev sets are padded with extra examples so that the
# tensors are all of a predictable size for all_gather. This line removes those extra examples
if task == 'almond':
return field.reverse(tensor, detokenizer=lambda x: ' '.join(x))[:clip]
return field.reverse(tensor, detokenize=lambda x: ' '.join(x))[:clip]
else:
return field.reverse(tensor)[:clip]