fixing build for predict
This commit is contained in:
parent
a35c7b8f5c
commit
d776d9f600
|
@ -104,13 +104,13 @@ def run(args, field, val_sets, model):
|
||||||
if not os.path.exists(prediction_file_name):
|
if not os.path.exists(prediction_file_name):
|
||||||
with open(prediction_file_name, 'a') as prediction_file:
|
with open(prediction_file_name, 'a') as prediction_file:
|
||||||
predictions = []
|
predictions = []
|
||||||
|
wikisql_ids = []
|
||||||
for batch_idx, batch in enumerate(it):
|
for batch_idx, batch in enumerate(it):
|
||||||
_, p = model(batch)
|
_, p = model(batch)
|
||||||
p = field.reverse(p)
|
p = field.reverse(p)
|
||||||
for i, pp in enumerate(p):
|
for i, pp in enumerate(p):
|
||||||
if 'sql' in task:
|
if 'sql' in task:
|
||||||
wikisql_id = int(batch.wikisql_id[i])
|
wikisql_id = int(batch.wikisql_id[i])
|
||||||
wikisql_ids = wikisql_ids if wikisql_ids is not None else []
|
|
||||||
wikisql_ids.append(wikisql_id)
|
wikisql_ids.append(wikisql_id)
|
||||||
prediction_file.write(pp + '\n')
|
prediction_file.write(pp + '\n')
|
||||||
predictions.append(pp)
|
predictions.append(pp)
|
||||||
|
|
Loading…
Reference in New Issue