2018-06-20 06:22:34 +00:00
|
|
|
import torch
|
2018-11-07 23:06:41 +00:00
|
|
|
from util import pad, tokenizer
|
2018-06-20 06:22:34 +00:00
|
|
|
from metrics import compute_metrics
|
2018-11-07 23:06:41 +00:00
|
|
|
from text.torchtext.data.utils import get_tokenizer
|
|
|
|
|
2018-06-20 06:22:34 +00:00
|
|
|
|
2018-11-27 23:22:38 +00:00
|
|
|
def compute_validation_outputs(model, val_iter, field, iteration, optional_names=None):
    """Run *model* over every batch of *val_iter* and collect its outputs.

    Args:
        model: callable taking ``(batch, iteration)`` and returning ``(loss, predictions)``.
        val_iter: iterable of batches (torchtext-style, with tensor attributes).
        field: vocab-bearing field; ``field.vocab.stoi['<pad>']`` is used as pad value.
        iteration: training iteration counter, forwarded to the model.
        optional_names: extra batch attribute names to gather (e.g. 'context').

    Returns:
        Tuple ``(loss, predictions, answers, outputs)`` where loss is a
        concatenated tensor or None, predictions/answers are padded and
        concatenated along dim 0, and outputs is one concatenated tensor per
        optional name.
    """
    # Avoid the mutable-default-argument pitfall; treat None as "no extras".
    optional_names = optional_names or []
    pad_idx = field.vocab.stoi['<pad>']
    max_len = 150  # fixed width so tensors from different batches can be concatenated

    loss, predictions, answers = [], [], []
    outputs = [[] for _ in optional_names]
    for batch in val_iter:
        l, p = model(batch, iteration)
        loss.append(l)
        predictions.append(pad(p, max_len, dim=-1, val=pad_idx))
        # Datasets with an official evaluator store an example id instead of
        # answer tokens; keep the id so the caller can look the answer up.
        if hasattr(batch, 'wikisql_id'):
            a = batch.wikisql_id.data.cpu()
        elif hasattr(batch, 'squad_id'):
            a = batch.squad_id.data.cpu()
        elif hasattr(batch, 'woz_id'):
            a = batch.woz_id.data.cpu()
        else:
            a = pad(batch.answer.data.cpu(), max_len, dim=-1, val=pad_idx)
        answers.append(a)
        for opt_idx, optional_name in enumerate(optional_names):
            outputs[opt_idx].append(getattr(batch, optional_name).data.cpu())

    # Guard against an empty iterator as well as models that report no loss.
    loss = torch.cat(loss, 0) if loss and loss[0] is not None else None
    predictions = torch.cat(predictions, 0)
    answers = torch.cat(answers, 0)
    return loss, predictions, answers, [
        torch.cat([pad(x, max_len, dim=-1, val=pad_idx) for x in output], 0)
        for output in outputs
    ]
|
|
|
|
|
|
|
|
|
|
|
|
def get_clip(val_iter):
    """Return a negative slice index that trims the iterator's padding
    examples (added for distributed evaluation), or None when there are none.
    """
    extra = val_iter.extra
    if extra > 0:
        return -extra
    return None
|
|
|
|
|
|
|
|
|
2018-11-07 23:06:41 +00:00
|
|
|
def all_reverse(tensor, world_size, task, field, clip, dim=0):
    """Gather *tensor* across distributed workers (if any) and detokenize it.

    Args:
        tensor: token-id tensor to convert back to text.
        world_size: number of distributed processes; >1 triggers all_gather.
        task: task name; 'almond' uses a dedicated reverse path.
        field: field providing ``reverse`` / ``reverse_almond``.
        clip: negative slice bound removing padding examples, or None.
        dim: unused here; kept for interface compatibility.

    Returns:
        List of detokenized strings, clipped to the real example count.
    """
    if world_size > 1:
        tensor = tensor.float()  # tensors must be on cpu and float for all_gather
        all_tensors = [torch.zeros_like(tensor) for _ in range(world_size)]
        # all_gather is experimental for gloo; these barriers were found necessary
        torch.distributed.barrier()
        torch.distributed.all_gather(all_tensors, tensor)
        torch.distributed.barrier()
        tensor = torch.cat(all_tensors, 0).long()  # tensors must be long for reverse
    # For distributed training, dev sets are padded with extra examples so the
    # tensors have a predictable size for all_gather; [:clip] removes them.
    if task == 'almond':
        # Temporarily swap in the almond tokenizer. Fix: restore the field's
        # revtok state even if reverse_almond raises, so a failure here does
        # not poison subsequent (non-almond) reversals.
        setattr(field, 'use_revtok', False)
        setattr(field, 'tokenize', tokenizer)
        try:
            return field.reverse_almond(tensor)[:clip]
        finally:
            setattr(field, 'use_revtok', True)
            setattr(field, 'tokenize', get_tokenizer('revtok'))
    return field.reverse(tensor)[:clip]
|
2018-06-20 06:22:34 +00:00
|
|
|
|
|
|
|
|
2018-11-27 23:22:38 +00:00
|
|
|
def gather_results(model, val_iter, field, world_size, task, iteration, optional_names=None):
    """Compute validation outputs and detokenize them across all workers.

    Args:
        model, val_iter, field, iteration, optional_names: forwarded to
            :func:`compute_validation_outputs`.
        world_size, task: forwarded to :func:`all_reverse`.

    Returns:
        Tuple ``(loss, predictions, answers, outputs)`` with predictions and
        outputs detokenized; answers stay numeric when the dataset stores
        example ids instead of answer tokens.
    """
    # Avoid the mutable-default-argument pitfall.
    optional_names = optional_names or []
    loss, predictions, answers, outputs = compute_validation_outputs(
        model, val_iter, field, iteration, optional_names=optional_names)
    clip = get_clip(val_iter)
    example = val_iter.dataset.examples[0]
    # Id-style datasets (squad/wikisql/woz) keep answers as lookup ids; only
    # token-style answers are reversed back to text here.
    if not any(hasattr(example, name) for name in ('squad_id', 'wikisql_id', 'woz_id')):
        answers = all_reverse(answers, world_size, task, field, clip)
    return (loss,
            all_reverse(predictions, world_size, task, field, clip),
            answers,
            [all_reverse(x, world_size, task, field, clip) for x in outputs])
|
2018-06-20 06:22:34 +00:00
|
|
|
|
|
|
|
|
|
|
|
def print_results(keys, values, rank=None, num_print=1):
    """Print a small window of aligned example values, one key per line.

    Args:
        keys: label for each parallel list in *values*.
        values: one list per key, aligned by example index.
        rank: distributed rank; each rank prints its own window. None means
            start at example 0.
        num_print: number of examples to print.
    """
    print()
    start = rank * num_print if rank is not None else 0
    window = [val[start:start + num_print] for val in values]
    # zip(*window) pairs the i-th entry of every column; unlike indexing
    # values[0], it is safe when *values* is empty (the loop simply runs zero
    # times instead of raising IndexError).
    for example in zip(*window):
        for key, value in zip(keys, example):
            # some columns hold per-example lists; show only the first entry
            v = value[0] if isinstance(value, list) else value
            print(f'{key}: {repr(v)}')
    print()
|
|
|
|
|
|
|
|
|
2018-11-27 23:22:38 +00:00
|
|
|
def validate(task, val_iter, model, logger, field, world_size, rank, iteration, num_print=10, args=None):
    """Evaluate *model* on *val_iter*, print sample outputs, and score them.

    Args:
        task: task name string; substrings select which metrics to compute.
        val_iter: validation iterator over batches.
        model: model implementing ``eval()`` and ``(batch, iteration)`` call.
        logger: unused here; kept for interface compatibility.
        field: vocab-bearing field used for (de)tokenization.
        world_size, rank: distributed configuration.
        iteration: training iteration counter, forwarded to the model.
        num_print: number of examples each rank prints.
        args: run configuration; ``args.reverse_task_bool`` is read for almond
            tasks. NOTE(review): the default None would raise AttributeError
            for an 'almond' task — presumably callers always pass args; verify.

    Returns:
        Tuple ``(loss, metrics)`` with loss a tensor or None and metrics a
        dict produced by :func:`compute_metrics`.
    """
    with torch.no_grad():
        model.eval()
        required_names = ['greedy', 'answer']
        optional_names = ['context', 'question']
        loss, predictions, answers, results = gather_results(model, val_iter, field, world_size, task, iteration, optional_names=optional_names)
        # normalize the model's unknown-token marker to the metrics' convention
        predictions = [p.replace('UNK', 'OOV') for p in predictions]
        names = required_names + optional_names
        # id-style datasets returned lookup ids; resolve them to answer strings
        if hasattr(val_iter.dataset.examples[0], 'wikisql_id') or hasattr(val_iter.dataset.examples[0], 'squad_id') or hasattr(val_iter.dataset.examples[0], 'woz_id'):
            answers = [val_iter.dataset.all_answers[sid] for sid in answers.tolist()]
        # enable exactly the metric families relevant to this task
        metrics, answers = compute_metrics(predictions, answers,
                                           bleu='iwslt' in task or 'multi30k' in task or 'almond' in task,
                                           dialogue='woz' in task,
                                           rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task,
                                           func_accuracy='almond' in task and not args.reverse_task_bool, args=args)
        results = [predictions, answers] + results
        print_results(names, results, rank=rank, num_print=num_print)
        return loss, metrics
|