Updating predict for easier custom inference
This commit is contained in:
parent
f4d2e91d93
commit
d594b2bf12
34
predict.py
34
predict.py
|
@ -88,21 +88,21 @@ def run(args, field, val_sets, model):
|
|||
if 'sql' in task:
|
||||
ids_file_name = answer_file_name.replace('gold', 'ids')
|
||||
if os.path.exists(prediction_file_name):
|
||||
print('** ', prediction_file_name, ' already exists**')
|
||||
print('** ', prediction_file_name, ' already exists -- this is where predictions are stored **')
|
||||
if os.path.exists(answer_file_name):
|
||||
print('** ', answer_file_name, ' already exists**')
|
||||
print('** ', answer_file_name, ' already exists -- this is where ground truth answers are stored **')
|
||||
if os.path.exists(results_file_name):
|
||||
print('** ', results_file_name, ' already exists**')
|
||||
print('** ', results_file_name, ' already exists -- this is where metrics are stored **')
|
||||
with open(results_file_name) as results_file:
|
||||
for l in results_file:
|
||||
print(l)
|
||||
if not 'schema' in task:
|
||||
if not 'schema' in task and not args.overwrite_predictions and args.silent:
|
||||
continue
|
||||
for x in [prediction_file_name, answer_file_name, results_file_name]:
|
||||
os.makedirs(os.path.dirname(x), exist_ok=True)
|
||||
|
||||
if not os.path.exists(prediction_file_name):
|
||||
with open(prediction_file_name, 'a') as prediction_file:
|
||||
if not os.path.exists(prediction_file_name) or args.overwrite_predictions:
|
||||
with open(prediction_file_name, 'w') as prediction_file:
|
||||
predictions = []
|
||||
wikisql_ids = []
|
||||
for batch_idx, batch in enumerate(it):
|
||||
|
@ -127,7 +127,7 @@ def run(args, field, val_sets, model):
|
|||
return [it.dataset.all_answers[sid] for sid in an.tolist()]
|
||||
|
||||
if not os.path.exists(answer_file_name):
|
||||
with open(answer_file_name, 'a') as answer_file:
|
||||
with open(answer_file_name, 'w') as answer_file:
|
||||
answers = []
|
||||
for batch_idx, batch in enumerate(it):
|
||||
if hasattr(batch, 'wikisql_id'):
|
||||
|
@ -146,12 +146,19 @@ def run(args, field, val_sets, model):
|
|||
answers = [json.loads(x.strip()) for x in answer_file.readlines()]
|
||||
|
||||
if len(answers) > 0:
|
||||
metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task, dialogue='woz' in task,
|
||||
rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task, args=args)
|
||||
|
||||
print(metrics)
|
||||
if not os.path.exists(results_file_name):
|
||||
metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task or args.bleu, dialogue='woz' in task,
|
||||
rouge='cnn' in task or 'dailymail' in task or args.rouge, logical_form='sql' in task, corpus_f1='zre' in task, args=args)
|
||||
with open(results_file_name, 'w') as results_file:
|
||||
results_file.write(json.dumps(metrics) + '\n')
|
||||
else:
|
||||
with open(results_file_name) as results_file:
|
||||
metrics = json.loads(results_file.readlines()[0])
|
||||
|
||||
print(metrics)
|
||||
if not args.silent:
|
||||
for p, a in zip(predictions, answers):
|
||||
print(f'Prediction: {p}\nAnswer: {a}\n')
|
||||
|
||||
|
||||
def get_args():
|
||||
|
@ -164,6 +171,10 @@ def get_args():
|
|||
parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.')
|
||||
parser.add_argument('--embeddings', default='/decaNLP/.embeddings', type=str, help='where to save embeddings.')
|
||||
parser.add_argument('--checkpoint_name')
|
||||
parser.add_argument('--bleu', action='store_true', help='whether to use the bleu metric (always on for iwslt)')
|
||||
# Opt-in ROUGE scoring; run() also enables it when the task name contains 'cnn' or 'dailymail'.
parser.add_argument('--rouge', action='store_true', help='whether to use the rouge metric (always on for cnn, dailymail, and cnn_dailymail)')
|
||||
parser.add_argument('--overwrite_predictions', action='store_true', help='whether to overwrite previously written predictions')
|
||||
# When set, run() skips the per-example "Prediction / Answer" printout.
parser.add_argument('--silent', action='store_true', help='whether to suppress printing predictions to stdout')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -236,6 +247,7 @@ def get_best(args):
|
|||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
print(f'Arguments:\n{pformat(vars(args))}')
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpus}'
|
||||
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
|
Loading…
Reference in New Issue