updates
This commit is contained in:
parent
ac7d066fca
commit
9c5e7c31db
2
train.py
2
train.py
|
@ -161,7 +161,7 @@ def train(args, model, opt, train_iters, train_iterations, field, rank=0, world_
|
|||
# by training them and not others
|
||||
# once the specified number of rounds is completed,
|
||||
# switch to normal round robin training
|
||||
if rnd<args.jump_start:
|
||||
if rnd < args.jump_start:
|
||||
train_iterations = [0]*len(train_iterations)
|
||||
for _ in range(args.n_jump_start): train_iterations[_] = 1
|
||||
else:
|
||||
|
|
|
@ -87,7 +87,7 @@ def validate(task, val_iter, model, logger, field, world_size, rank, num_print=1
|
|||
if hasattr(val_iter.dataset.examples[0], 'wikisql_id') or hasattr(val_iter.dataset.examples[0], 'squad_id') or hasattr(val_iter.dataset.examples[0], 'woz_id'):
|
||||
answers = [val_iter.dataset.all_answers[sid] for sid in answers.tolist()]
|
||||
metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task or 'almond' in task, dialogue='woz' in task,
|
||||
rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,func_accuracy='almond' in task, args=args)
|
||||
rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
|
||||
results = [predictions, answers] + results
|
||||
print_results(names, results, rank=rank, num_print=num_print)
|
||||
|
||||
|
|
Loading…
Reference in New Issue