diff --git a/predict.py b/predict.py index 541ed2d9..927e6ac0 100644 --- a/predict.py +++ b/predict.py @@ -110,6 +110,7 @@ def run(args, field, val_sets, model): for i, pp in enumerate(p): if 'sql' in task: wikisql_id = int(batch.wikisql_id[i]) + wikisql_ids = wikisql_ids if wikisql_ids is not None else [] wikisql_ids.append(wikisql_id) prediction_file.write(pp + '\n') predictions.append(pp)