updates and fixes

This commit is contained in:
parent fdfdd154c4
commit 8c54dc1391

@@ -25,7 +25,7 @@ def parse():
     Returns the arguments from the command line.
     """
     parser = ArgumentParser()
-    parser.add_argument('--root', default='/decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
+    parser.add_argument('--root', default='./decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.')
     parser.add_argument('--data', default='.data/', type=str, help='where to load data from.')
     parser.add_argument('--save', default='results', type=str, help='where to save results.')
     parser.add_argument('--embeddings', default='.embeddings', type=str, help='where to save embeddings.')
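The only change in this hunk flips the --root default from the absolute /decaNLP to the relative ./decaNLP, so data, results, and embeddings resolve under the current checkout. A minimal sketch of how these defaults are typically combined, assuming the other directories are joined under --root (the join logic below is illustrative, not code from this commit):

import os
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--root', default='./decaNLP', type=str)
parser.add_argument('--data', default='.data/', type=str)
parser.add_argument('--save', default='results', type=str)
parser.add_argument('--embeddings', default='.embeddings', type=str)
args = parser.parse_args([])  # use the defaults

# With a relative --root, everything lands under the working directory
# instead of an absolute /decaNLP path (e.g. a container mount).
data_dir = os.path.join(args.root, args.data)
save_dir = os.path.join(args.root, args.save)
embeddings_dir = os.path.join(args.root, args.embeddings)
print(data_dir, save_dir, embeddings_dir)
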
@@ -11,7 +11,7 @@ from util import get_trainable_params, set_seed
 from modules import expectedBLEU, expectedMultiBleu, matrixBLEU
 
 from cove import MTLSTM
-#from allennlp.modules.elmo import Elmo, batch_to_ids
+from allennlp.modules.elmo import Elmo, batch_to_ids
 
 from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask, CoattentiveLayer
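This hunk re-enables the AllenNLP ELMo import. For context, a sketch of the usual allennlp 0.x Elmo pattern that the import supports, consistent with how the model calls it in the hunk below; the options and weight file names are placeholders, not paths from this repository, and must point at real pretrained ELMo files to run:

from allennlp.modules.elmo import Elmo, batch_to_ids

options_file = 'elmo_options.json'  # placeholder: pretrained ELMo options
weight_file = 'elmo_weights.hdf5'   # placeholder: pretrained ELMo weights
elmo = Elmo(options_file, weight_file, num_output_representations=1, dropout=0.0)

# batch_to_ids turns a batch of token lists into character ids,
# and the Elmo module returns contextual embeddings for them.
sentences = [['<S>', 'what', 'is', 'the', 'answer', '</S>']]
character_ids = batch_to_ids(sentences)                      # (batch, tokens, 50)
embeddings = elmo(character_ids)['elmo_representations'][0]  # (batch, tokens, dim); dim is 1024 for the standard full ELMo
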
@@ -119,29 +119,6 @@ class MultitaskQuestionAnsweringNetwork(nn.Module):
             else:
                 context_embedded, question_embedded = context_elmo, question_elmo
 
-        elif self.args.elmo:
-            context_list = []
-            for b in range(context.size(0)):
-                #FIXME
-                lc = [self.field.decoder_itos[i] if i < self.args.max_generative_vocab else self.field.decoder_itos[0] for i in context[b, :]]
-                if lc[0] == self.field.decoder_itos[2]:
-                    lc[0] = '<S>'
-                if lc[-1] == self.field.decoder_itos[3]:
-                    lc[-1] = '</S>'
-                context_list.append(lc)
-            question_list = []
-            for b in range(question.size(0)):
-                lq = [self.field.decoder_itos[i] if i < self.args.max_generative_vocab else self.field.decoder_itos[0] for i in question[b, :]]
-                if lq[0] == self.field.decoder_itos[2]:
-                    lq[0] = '<S>'
-                if lq[-1] == self.field.decoder_itos[3]:
-                    lq[-1] = '</S>'
-                question_list.append(lq)
-
-            context_embedded = self.project_elmo(torch.cat([self.elmo(batch_to_ids(context_list).to(self.device))['elmo_representations'][0], context_embedded], -1).detach())
-            question_embedded = self.project_elmo(torch.cat([self.elmo(batch_to_ids(question_list).to(self.device))['elmo_representations'][0], question_embedded], -1).detach())
-
-
         context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0]
         question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0]
 
@@ -15,7 +15,6 @@ import models
 from text.torchtext.data.utils import get_tokenizer
 
-
 
 def get_all_splits(args, new_vocab):
     splits = []
     for task in args.tasks:
@@ -171,11 +170,7 @@ def run(args, field, val_sets, model):
                a = from_all_answers(batch.woz_id.data.cpu())
            else:
                if task == 'almond':
-                   setattr(field, 'use_revtok', False)
-                   setattr(field, 'tokenize', tokenizer)
-                   a = field.reverse_almond(batch.answer.data)
-                   setattr(field, 'use_revtok', True)
-                   setattr(field, 'tokenize', 'revtok')
+                   a = field.reverse(batch.answer.data, detokenize=lambda x: ' '.join(x))
                else:
                    a = field.reverse(batch.answer.data)
                for aa in a:
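For the almond task, the revtok toggling and reverse_almond call are collapsed into a single field.reverse call with a whitespace-joining detokenizer. A standalone illustration of what that callback does (the token list is made up for the example; the real answers are the formal programs decoded by the field):

def detokenize(tokens):
    # Join tokens with single spaces; formal-language answers must keep their
    # token boundaries exactly, so no revtok detokenization is applied.
    return ' '.join(tokens)

tokens = ['now', '=>', '@com.twitter.post', 'param:status', '=', 'QUOTED_STRING_0']
print(detokenize(tokens))
# now => @com.twitter.post param:status = QUOTED_STRING_0
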
@@ -322,7 +317,6 @@ if __name__ == '__main__':
 
    print(f'Loading from {args.best_checkpoint}')
 
-   # save_dict = torch.load(args.best_checkpoint)
    if torch.cuda.is_available():
        save_dict = torch.load(args.best_checkpoint)
    else:
train.py (2 changed lines)
@@ -84,7 +84,7 @@ def prepare_data(args, field, logger):
 
    for task, s in zip(args.train_tasks, train_sets):
        for ex in s.examples[:10]:
-           print('examples***:', ex.context)
+           print('examples***:', [token.strip() for token in ex.context])
 
    if args.load is None:
        logger.info(f'Getting pretrained word vectors')
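Both this hunk and the util.py hunk below strip each token before printing or logging it. A tiny illustration of the difference, assuming the tokens come from the reversible (revtok-style) tokenizer, which keeps trailing whitespace attached to tokens:

tokens = ['what ', 'is ', 'the ', 'answer']           # tokens with attached whitespace
print('examples***:', tokens)                         # old output: raw tokens, noisy spacing
print('examples***:', [t.strip() for t in tokens])    # new output: stripped tokens
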
util.py (8 changed lines)
@@ -64,10 +64,10 @@ def preprocess_examples(args, tasks, splits, field, logger=None, train=True):
        if logger is not None:
            logger.info('Tokenized examples:')
            for ex in s.examples[:10]:
-               logger.info('Context: ' + ' '.join(ex.context))
-               logger.info('Question: ' + ' '.join(ex.question))
-               logger.info(' '.join(ex.context_question))
-               logger.info('Answer: ' + ' '.join(ex.answer))
+               logger.info('Context: ' + ' '.join([token.strip() for token in ex.context]))
+               logger.info('Question: ' + ' '.join([token.strip() for token in ex.question]))
+               logger.info(' '.join([token.strip() for token in ex.context_question]))
+               logger.info('Answer: ' + ' '.join([token.strip() for token in ex.answer]))
 
 
 
@@ -44,7 +44,7 @@ def all_reverse(tensor, world_size, task, field, clip, dim=0):
    # for distributed training, dev sets are padded with extra examples so that the
    # tensors are all of a predictable size for all_gather. This line removes those extra examples
    if task == 'almond':
-       return field.reverse(tensor, detokenizer=lambda x: ' '.join(x))[:clip]
+       return field.reverse(tensor, detokenize=lambda x: ' '.join(x))[:clip]
    else:
        return field.reverse(tensor)[:clip]
 
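The last fix renames the keyword from detokenizer to detokenize, matching the parameter name used in the prediction hunk above. A quick stand-in showing why the old spelling fails; the reverse function below is a placeholder with the assumed keyword name, not the project's real field method:

def reverse(tensor, detokenize=None):
    # Placeholder: returns its input unchanged, only the signature matters here.
    return tensor

try:
    reverse([], detokenizer=lambda x: ' '.join(x))   # old keyword name
except TypeError as err:
    print(err)  # reverse() got an unexpected keyword argument 'detokenizer'

reverse([], detokenize=lambda x: ' '.join(x))        # corrected keyword name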