fixing build for predict

This commit is contained in:
Bryan Marcus McCann 2018-08-16 19:51:40 +00:00
parent a35c7b8f5c
commit d776d9f600
1 changed files with 1 additions and 1 deletions

View File

@ -104,13 +104,13 @@ def run(args, field, val_sets, model):
if not os.path.exists(prediction_file_name):
with open(prediction_file_name, 'a') as prediction_file:
predictions = []
wikisql_ids = []
for batch_idx, batch in enumerate(it):
_, p = model(batch)
p = field.reverse(p)
for i, pp in enumerate(p):
if 'sql' in task:
wikisql_id = int(batch.wikisql_id[i])
wikisql_ids = wikisql_ids if wikisql_ids is not None else []
wikisql_ids.append(wikisql_id)
prediction_file.write(pp + '\n')
predictions.append(pp)