fixing build for predict
This commit is contained in:
parent
a35c7b8f5c
commit
d776d9f600
|
@ -104,13 +104,13 @@ def run(args, field, val_sets, model):
|
|||
if not os.path.exists(prediction_file_name):
|
||||
with open(prediction_file_name, 'a') as prediction_file:
|
||||
predictions = []
|
||||
wikisql_ids = []
|
||||
for batch_idx, batch in enumerate(it):
|
||||
_, p = model(batch)
|
||||
p = field.reverse(p)
|
||||
for i, pp in enumerate(p):
|
||||
if 'sql' in task:
|
||||
wikisql_id = int(batch.wikisql_id[i])
|
||||
wikisql_ids = wikisql_ids if wikisql_ids is not None else []
|
||||
wikisql_ids.append(wikisql_id)
|
||||
prediction_file.write(pp + '\n')
|
||||
predictions.append(pp)
|
||||
|
|
Loading…
Reference in New Issue