From d594b2bf127e13d0e61151b6a2af3bf63612f380 Mon Sep 17 00:00:00 2001 From: Bryan Marcus McCann Date: Fri, 31 Aug 2018 01:06:09 +0000 Subject: [PATCH] Updating predict for easier custom inference --- predict.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/predict.py b/predict.py index 0920ae09..ca608a02 100644 --- a/predict.py +++ b/predict.py @@ -88,21 +88,21 @@ def run(args, field, val_sets, model): if 'sql' in task: ids_file_name = answer_file_name.replace('gold', 'ids') if os.path.exists(prediction_file_name): - print('** ', prediction_file_name, ' already exists**') + print('** ', prediction_file_name, ' already exists -- this is where predictions are stored **') if os.path.exists(answer_file_name): - print('** ', answer_file_name, ' already exists**') + print('** ', answer_file_name, ' already exists -- this is where ground truth answers are stored **') if os.path.exists(results_file_name): - print('** ', results_file_name, ' already exists**') + print('** ', results_file_name, ' already exists -- this is where metrics are stored **') with open(results_file_name) as results_file: for l in results_file: print(l) - if not 'schema' in task: + if not 'schema' in task and not args.overwrite_predictions and args.silent: continue for x in [prediction_file_name, answer_file_name, results_file_name]: os.makedirs(os.path.dirname(x), exist_ok=True) - if not os.path.exists(prediction_file_name): - with open(prediction_file_name, 'a') as prediction_file: + if not os.path.exists(prediction_file_name) or args.overwrite_predictions: + with open(prediction_file_name, 'w') as prediction_file: predictions = [] wikisql_ids = [] for batch_idx, batch in enumerate(it): @@ -127,7 +127,7 @@ def run(args, field, val_sets, model): return [it.dataset.all_answers[sid] for sid in an.tolist()] if not os.path.exists(answer_file_name): - with open(answer_file_name, 'a') as answer_file: + with open(answer_file_name, 'w') as answer_file: answers = [] for batch_idx, batch in enumerate(it): if hasattr(batch, 'wikisql_id'): @@ -146,12 +146,19 @@ def run(args, field, val_sets, model): answers = [json.loads(x.strip()) for x in answer_file.readlines()] if len(answers) > 0: - metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task, dialogue='woz' in task, - rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task, args=args) + if not os.path.exists(results_file_name): + metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task or args.bleu, dialogue='woz' in task, + rouge='cnn' in task or 'dailymail' in task or args.rouge, logical_form='sql' in task, corpus_f1='zre' in task, args=args) + with open(results_file_name, 'w') as results_file: + results_file.write(json.dumps(metrics) + '\n') + else: + with open(results_file_name) as results_file: + metrics = json.loads(results_file.readlines()[0]) print(metrics) - with open(results_file_name, 'w') as results_file: - results_file.write(json.dumps(metrics) + '\n') + if not args.silent: + for p, a in zip(predictions, answers): + print(f'Prediction: {p}\nAnswer: {a}\n') def get_args(): @@ -164,6 +171,10 @@ def get_args(): parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.') parser.add_argument('--embeddings', default='/decaNLP/.embeddings', type=str, help='where to save embeddings.') parser.add_argument('--checkpoint_name') + parser.add_argument('--bleu', action='store_true', help='whether to use the bleu metric (always on for iwslt)') + parser.add_argument('--rouge', action='store_true', help='whether to use the bleu metric (always on for cnn, dailymail, and cnn_dailymail)') + parser.add_argument('--overwrite_predictions', action='store_true', help='whether to overwrite previously written predictions') + parser.add_argument('--silent', action='store_true', help='whether to print predictions to stdout') args = parser.parse_args() @@ -236,6 +247,7 @@ def get_best(args): if __name__ == '__main__': args = get_args() print(f'Arguments:\n{pformat(vars(args))}') + os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpus}' np.random.seed(args.seed) random.seed(args.seed)