touching up predict
parent fb0dcaab35
commit eab80b6cea
predict.py (25 changed lines)
@@ -64,6 +64,8 @@ def to_iter(data, bs, device):
 def run(args, field, val_sets, model):
     device = set_seed(args)
     print(f'Preparing iterators')
+    if len(args.val_batch_size) == 1 and len(val_sets) > 1:
+        args.val_batch_size *= len(val_sets)
     iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]
 
 def mult(ps):
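The two added lines broadcast a single --val_batch_size value across all validation sets. A minimal sketch with made-up values (args.val_batch_size is a Python list, so *= repeats it rather than multiplying numbers):

# Hypothetical values for illustration only.
val_batch_size = [64]                        # one batch size given on the CLI
val_sets = ['squad', 'iwslt.en.de', 'sst']   # three validation sets

if len(val_batch_size) == 1 and len(val_sets) > 1:
    val_batch_size *= len(val_sets)          # list repetition, not arithmetic

print(val_batch_size)                        # [64, 64, 64]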
@@ -82,10 +84,11 @@ def run(args, field, val_sets, model):
     model.eval()
     with torch.no_grad():
         for task, it in iters:
             print(task)
             prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.txt')
             answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.gold.txt')
-            if 'sql' in task:
+            results_file_name = answer_file_name.replace('gold', 'results')
+            if 'sql' in task or 'squad' in task:
                 ids_file_name = answer_file_name.replace('gold', 'ids')
             if os.path.exists(prediction_file_name):
                 print('** ', prediction_file_name, ' already exists -- this is where predictions are stored **')
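To see what this block derives, here is a quick sketch with a hypothetical checkpoint path (the paths and values below are assumptions, not taken from the commit):

import os

best_checkpoint = '/models/decanlp/best.pth'   # hypothetical path
evaluate, task = 'validation', 'squad'         # hypothetical values

base = os.path.splitext(best_checkpoint)[0]    # '/models/decanlp/best'
prediction_file_name = os.path.join(base, evaluate, task + '.txt')
answer_file_name = os.path.join(base, evaluate, task + '.gold.txt')
results_file_name = answer_file_name.replace('gold', 'results')
ids_file_name = answer_file_name.replace('gold', 'ids')

print(prediction_file_name)  # /models/decanlp/best/validation/squad.txt
print(ids_file_name)         # /models/decanlp/best/validation/squad.ids.txt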
@@ -104,14 +107,15 @@ def run(args, field, val_sets, model):
             if not os.path.exists(prediction_file_name) or args.overwrite_predictions:
                 with open(prediction_file_name, 'w') as prediction_file:
                     predictions = []
-                    wikisql_ids = []
+                    ids = []
                     for batch_idx, batch in enumerate(it):
                         _, p = model(batch)
                         p = field.reverse(p)
                         for i, pp in enumerate(p):
                             if 'sql' in task:
-                                wikisql_id = int(batch.wikisql_id[i])
-                                wikisql_ids.append(wikisql_id)
+                                ids.append(int(batch.wikisql_id[i]))
+                            if 'squad' in task:
+                                ids.append(it.dataset.q_ids[int(batch.squad_id[i])])
                             prediction_file.write(pp + '\n')
                             predictions.append(pp)
             else:
@@ -120,9 +124,14 @@ def run(args, field, val_sets, model):
 
             if 'sql' in task:
                 with open(ids_file_name, 'w') as id_file:
-                    for i in wikisql_ids:
+                    for i in ids:
                         id_file.write(json.dumps(i) + '\n')
 
+            if 'squad' in task:
+                with open(ids_file_name, 'w') as id_file:
+                    for i in ids:
+                        id_file.write(i + '\n')
+
             def from_all_answers(an):
                 return [it.dataset.all_answers[sid] for sid in an.tolist()]
 
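Note the asymmetry the two blocks above preserve: WikiSQL ids are ints, so they go through json.dumps, while SQuAD ids are already strings and are written raw. A small sketch with made-up ids:

import json

wikisql_style_ids = [3, 17, 42]                 # ints from batch.wikisql_id
squad_style_ids = ['56be4db0acb8001400a502ec']  # string ids from q_ids (made up)

with open('wikisql.ids.txt', 'w') as id_file:
    for i in wikisql_style_ids:
        id_file.write(json.dumps(i) + '\n')     # one JSON int per line

with open('squad.ids.txt', 'w') as id_file:
    for i in squad_style_ids:
        id_file.write(i + '\n')                 # one raw string per line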
@@ -165,7 +174,7 @@ def get_args():
     parser = ArgumentParser()
     parser.add_argument('--path', required=True)
     parser.add_argument('--evaluate', type=str, required=True)
-    parser.add_argument('--tasks', default=['wikisql', 'woz.en', 'cnn_dailymail', 'iwslt.en.de', 'zre', 'srl', 'squad', 'sst', 'multinli.in.out', 'schema'], nargs='+')
+    parser.add_argument('--tasks', default=['squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], nargs='+')
     parser.add_argument('--gpus', default=[0], nargs='+', type=int, help='a list of gpus that can be used (multi-gpu currently WIP)')
     parser.add_argument('--seed', default=123, type=int, help='Random seed.')
     parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.')
@@ -180,7 +189,7 @@ def get_args():
 
     with open(os.path.join(args.path, 'config.json')) as config_file:
         config = json.load(config_file)
-    retrieve = ['model', 'val_batch_size',
+    retrieve = ['model',
                 'transformer_layers', 'rnn_layers', 'transformer_hidden',
-                'dimension', 'load', 'max_val_context_length',
+                'dimension', 'load', 'max_val_context_length', 'val_batch_size',
                 'transformer_heads', 'max_output_length', 'max_generative_vocab',
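The retrieve list names the saved settings to copy from config.json back onto args; 'val_batch_size' simply moves to a later line of the list. The loop that applies the list is not shown in this diff, so the following is a hedged sketch of the usual pattern (all names and values below are assumptions):

from argparse import Namespace

args = Namespace()                                   # stand-in for parsed CLI args
config = {'model': 'MQAN', 'val_batch_size': [256]}  # made-up saved config
retrieve = ['model', 'val_batch_size']               # abbreviated from the list above

for key in retrieve:
    setattr(args, key, config[key])                  # copy each saved setting onto args

print(args.model, args.val_batch_size)               # MQAN [256]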
The remaining hunks modify a second file, the dataset module that defines class SQuAD(CQA, data.Dataset); its file name is not preserved in this capture.

@@ -211,10 +211,10 @@ class SQuAD(CQA, data.Dataset):
         fields = [(x, field) for x in self.fields]
         cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
 
-        examples, all_answers = [], []
+        examples, all_answers, q_ids = [], [], []
         if os.path.exists(cache_name):
             print(f'Loading cached data from {cache_name}')
-            examples, all_answers = torch.load(cache_name)
+            examples, all_answers, q_ids = torch.load(cache_name)
         else:
             with open(os.path.expanduser(path)) as f:
                 squad = json.load(f)['data']
|
@ -226,6 +226,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
qas = paragraph['qas']
|
||||
for qa in qas:
|
||||
question = ' '.join(qa['question'].split())
|
||||
q_ids.append(qa['id'])
|
||||
squad_id = len(all_answers)
|
||||
context_question = get_context_question(context, question)
|
||||
if len(qa['answers']) == 0:
|
||||
|
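For context, a minimal made-up record showing where qa['id'] sits in SQuAD-format JSON and why the question is re-joined to collapse whitespace:

squad = [{
    'paragraphs': [{
        'context': 'Tesla was born in 1856.',
        'qas': [{
            'id': '56be4db0acb8001400a502ec',        # the value q_ids collects
            'question': 'When  was   Tesla  born?',  # raw text may have odd spacing
            'answers': [{'text': '1856', 'answer_start': 18}],
        }],
    }],
}]

q_ids = []
for article in squad:
    for paragraph in article['paragraphs']:
        for qa in paragraph['qas']:
            question = ' '.join(qa['question'].split())  # 'When was Tesla born?'
            q_ids.append(qa['id'])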
@@ -303,7 +304,7 @@ class SQuAD(CQA, data.Dataset):
 
             os.makedirs(os.path.dirname(cache_name), exist_ok=True)
             print(f'Caching data to {cache_name}')
-            torch.save((examples, all_answers), cache_name)
+            torch.save((examples, all_answers, q_ids), cache_name)
 
 
         FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
@@ -315,6 +316,7 @@ class SQuAD(CQA, data.Dataset):
 
         super(SQuAD, self).__init__(examples, fields, **kwargs)
         self.all_answers = all_answers
+        self.q_ids = q_ids
 
 
     @classmethod
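One practical consequence of extending the cached tuple: caches written before this commit hold a 2-tuple, so unpacking three values from a stale cache raises ValueError. A hypothetical guard (not part of the commit) would look like:

import torch

def load_squad_cache(cache_name):
    """Hypothetical helper: tolerate caches saved before q_ids existed."""
    cached = torch.load(cache_name)
    if len(cached) == 3:
        examples, all_answers, q_ids = cached
    else:                       # stale 2-tuple cache from an older version
        examples, all_answers = cached
        q_ids = []              # rebuild the cache to repopulate q_ids
    return examples, all_answers, q_ids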