This commit is contained in:
mehrad 2018-11-06 14:13:27 -08:00
parent ac7d066fca
commit 9c5e7c31db
2 changed files with 2 additions and 2 deletions

View File

@ -161,7 +161,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
# by training them and not others
# once the specified number of rounds is completed,
# switch to normal round robin training
if rnd<args.jump_start:
if rnd < args.jump_start:
train_iterations = [0]*len(train_iterations)
for _ in range(args.n_jump_start): train_iterations[_] = 1
else:

View File

@ -87,7 +87,7 @@ def validate(task, val_iter, model, logger, field, world_size, rank, num_print=1
if hasattr(val_iter.dataset.examples[0], 'wikisql_id') or hasattr(val_iter.dataset.examples[0], 'squad_id') or hasattr(val_iter.dataset.examples[0], 'woz_id'):
answers = [val_iter.dataset.all_answers[sid] for sid in answers.tolist()]
metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task or 'almond' in task, dialogue='woz' in task,
rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,func_accuracy='almond' in task, args=args)
rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
results = [predictions, answers] + results
print_results(names, results, rank=rank, num_print=num_print)