# genienlp/decanlp/predict.py
#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
from .utils.generic_dataset import Query
from .text import torchtext
from argparse import ArgumentParser
import ujson as json
import torch
import numpy as np
import random
import sys
import logging
from pprint import pformat
from .util import set_seed, preprocess_examples, load_config_json
from .metrics import compute_metrics
from .utils.embeddings import load_embeddings
from .tasks.registry import get_tasks
from . import models
logger = logging.getLogger(__name__)
def get_all_splits(args, new_vocab):
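    """Load the dataset splits named by args.evaluate for every task in args.tasks,
    preprocess their examples with the given vocabulary field, and return them.
    """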
    splits = []
    for task in args.tasks:
        logger.info(f'Loading {task}')
        kwargs = {}
        if 'train' not in args.evaluate:
            kwargs['train'] = None
        if 'valid' not in args.evaluate:
            kwargs['validation'] = None
        if 'test' not in args.evaluate:
            kwargs['test'] = None
        kwargs['skip_cache_bool'] = args.skip_cache_bool
        kwargs['cached_path'] = args.cached
        s = task.get_splits(new_vocab, root=args.data, **kwargs)[0]
        preprocess_examples(args, [task], [s], new_vocab, train=False)
        splits.append(s)
    return splits

def prepare_data(args, FIELD):
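    """Build a vocabulary over the evaluation splits, merge it into the field saved
    at training time, reload word vectors for the expanded vocabulary, and return
    the updated field together with freshly loaded splits.
    """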
    new_vocab = torchtext.data.ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
    splits = get_all_splits(args, new_vocab)
    new_vocab.build_vocab(*splits)
    logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens from training')
    args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab)
    FIELD.append_vocab(new_vocab)
    logger.info(f'Vocabulary has expanded to {len(FIELD.vocab)} tokens')

    vectors = load_embeddings(args)
    FIELD.vocab.load_vectors(vectors, True)
    FIELD.decoder_to_vocab = {idx: FIELD.vocab.stoi[word] for idx, word in enumerate(FIELD.decoder_itos)}
    FIELD.vocab_to_decoder = {idx: FIELD.decoder_stoi[word] for idx, word in enumerate(FIELD.vocab.itos) if word in FIELD.decoder_stoi}
    splits = get_all_splits(args, FIELD)
    return FIELD, splits

def to_iter(data, bs, device):
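    """Wrap a dataset in a non-shuffling, non-repeating torchtext Iterator for evaluation."""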
    Iterator = torchtext.data.Iterator
    it = Iterator(data, batch_size=bs,
                  device=device, batch_size_fn=None,
                  train=False, repeat=False, sort=False,
                  shuffle=False, reverse=False)
    return it

def run(args, field, val_sets, model):
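    """Evaluate the model on each task, writing predictions, gold answers, and
    metrics to disk, then log per-task scores and the summed decaScore.
    """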
    device = set_seed(args)
    logger.info('Preparing iterators')
    if len(args.val_batch_size) == 1 and len(val_sets) > 1:
        args.val_batch_size *= len(val_sets)
    iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]

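    # Count trainable parameters: sum over tensors of the product of their dimensions.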
    def mult(ps):
        r = 0
        for p in ps:
            this_r = 1
            for s in p.size():
                this_r *= s
            r += this_r
        return r

    params = list(filter(lambda p: p.requires_grad, model.parameters()))
    num_param = mult(params)
    logger.info(f'{args.model} has {num_param:,} parameters')
    model.to(device)

    decaScore = []
    model.eval()

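    # Evaluate each task in turn, writing predictions, gold answers, and metrics
    # to per-task files and reusing existing files unless --overwrite is given.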
    with torch.no_grad():
        for task, it in iters:
            logger.info(task.name)
            if args.eval_dir:
                prediction_file_name = os.path.join(args.eval_dir, os.path.join(args.evaluate, task.name + '.txt'))
                answer_file_name = os.path.join(args.eval_dir, os.path.join(args.evaluate, task.name + '.gold.txt'))
            else:
                prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task.name + '.txt')
                answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task.name + '.gold.txt')
            results_file_name = answer_file_name.replace('gold', 'results')
            if 'sql' in task.name or 'squad' in task.name:
                ids_file_name = answer_file_name.replace('gold', 'ids')
            if os.path.exists(prediction_file_name):
                logger.warning(f'** {prediction_file_name} already exists -- this is where predictions are stored **')
                if args.overwrite:
                    logger.warning(f'**** overwriting {prediction_file_name} ****')
            if os.path.exists(answer_file_name):
                logger.warning(f'** {answer_file_name} already exists -- this is where ground truth answers are stored **')
                if args.overwrite:
                    logger.warning(f'**** overwriting {answer_file_name} ****')
            if os.path.exists(results_file_name):
                logger.warning(f'** {results_file_name} already exists -- this is where metrics are stored **')
                if args.overwrite:
                    logger.warning(f'**** overwriting {results_file_name} ****')
                else:
                    with open(results_file_name) as results_file:
                        lines = results_file.readlines()
                    if not args.silent:
                        for l in lines:
                            logger.debug(l)
                    metrics = json.loads(lines[0])
                    decaScore.append(metrics[task.metrics[0]])
                    continue

            for x in [prediction_file_name, answer_file_name, results_file_name]:
                os.makedirs(os.path.dirname(x), exist_ok=True)

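            # Run the model over the split and write one JSON-encoded prediction per
            # line, tracking example ids for wikisql and squad so answers can be
            # matched up later.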
            if not os.path.exists(prediction_file_name) or args.overwrite:
                with open(prediction_file_name, 'w') as prediction_file:
                    predictions = []
                    ids = []
                    for batch_idx, batch in enumerate(it):
                        _, p = model(batch, iteration=1)
                        p = field.reverse(p, detokenize=task.detokenize)
                        for i, pp in enumerate(p):
                            if 'sql' in task.name:
                                ids.append(int(batch.wikisql_id[i]))
                            if 'squad' in task.name:
                                ids.append(it.dataset.q_ids[int(batch.squad_id[i])])
                            prediction_file.write(json.dumps(pp) + '\n')
                            predictions.append(pp)
                if 'sql' in task.name:
                    with open(ids_file_name, 'w') as id_file:
                        for i in ids:
                            id_file.write(json.dumps(i) + '\n')
                if 'squad' in task.name:
                    with open(ids_file_name, 'w') as id_file:
                        for i in ids:
                            id_file.write(i + '\n')
            else:
                with open(prediction_file_name) as prediction_file:
                    predictions = [x.strip() for x in prediction_file.readlines()]
                if 'sql' in task.name or 'squad' in task.name:
                    with open(ids_file_name) as id_file:
                        ids = [int(x.strip()) for x in id_file.readlines()]

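            # For tasks whose gold answers live in the dataset (wikisql, squad, woz),
            # map stored answer indices back to the answers themselves.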
            def from_all_answers(an):
                return [it.dataset.all_answers[sid] for sid in an.tolist()]

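            # Write gold answers once per task; later runs read them back from disk.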
            if not os.path.exists(answer_file_name) or args.overwrite:
                with open(answer_file_name, 'w') as answer_file:
                    answers = []
                    for batch_idx, batch in enumerate(it):
                        if hasattr(batch, 'wikisql_id'):
                            a = from_all_answers(batch.wikisql_id.data.cpu())
                        elif hasattr(batch, 'squad_id'):
                            a = from_all_answers(batch.squad_id.data.cpu())
                        elif hasattr(batch, 'woz_id'):
                            a = from_all_answers(batch.woz_id.data.cpu())
                        else:
                            a = field.reverse(batch.answer.data, detokenize=task.detokenize)
                        for aa in a:
                            answers.append(aa)
                            answer_file.write(json.dumps(aa) + '\n')
            else:
                with open(answer_file_name) as answer_file:
                    answers = [json.loads(x.strip()) for x in answer_file.readlines()]

            if len(answers) > 0:
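                # Score predictions against answers and cache the metrics on disk.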
                if not os.path.exists(results_file_name) or args.overwrite:
                    metrics, answers = compute_metrics(predictions, answers, task.metrics, args=args)
                    with open(results_file_name, 'w') as results_file:
                        results_file.write(json.dumps(metrics) + '\n')
                else:
                    with open(results_file_name) as results_file:
                        metrics = json.loads(results_file.readlines()[0])
                if not args.silent:
                    for i, (p, a) in enumerate(zip(predictions, answers)):
                        logger.info(f'Prediction {i+1}: {p}\nAnswer {i+1}: {a}\n')
                logger.info(metrics)
                decaScore.append(metrics[task.metrics[0]])

    logger.info('Evaluated Tasks:\n')
    for i, (task, _) in enumerate(iters):
        logger.info(f'{task.name}: {decaScore[i]}')
    logger.info('-------------------')
    logger.info(f'DecaScore: {sum(decaScore)}\n')
    logger.info(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n')

def get_args(argv):
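    """Parse command-line arguments, resolve task names into task objects, and
    load the configuration saved with the trained model.
    """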
    parser = ArgumentParser(prog=argv[0])
    parser.add_argument('--path', required=True)
    parser.add_argument('--evaluate', type=str, required=True)
    parser.add_argument('--tasks', default=['almond', 'squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], dest='task_names', nargs='+')
    parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)')
    parser.add_argument('--seed', default=123, type=int, help='Random seed.')
    parser.add_argument('--data', default='./decaNLP/.data/', type=str, help='where to load data from.')
    parser.add_argument('--embeddings', default='./decaNLP/.embeddings', type=str, help='where to save embeddings.')
    parser.add_argument('--checkpoint_name', default='best.pth', help='Checkpoint file to use (relative to --path, defaults to best.pth)')
    parser.add_argument('--bleu', action='store_true', help='whether to use the BLEU metric (always on for iwslt)')
    parser.add_argument('--rouge', action='store_true', help='whether to use the ROUGE metric (always on for cnn, dailymail, and cnn_dailymail)')
    parser.add_argument('--overwrite', action='store_true', help='whether to overwrite previously written predictions')
    parser.add_argument('--silent', action='store_true', help='whether to suppress printing predictions to the log')
    parser.add_argument('--skip_cache', action='store_true', dest='skip_cache_bool', help='whether to skip existing cached splits and generate new ones')
    parser.add_argument('--eval_dir', type=str, default=None, help='use this directory to store eval results')
    parser.add_argument('--cached', default='', type=str, help='where to save cached files')

    args = parser.parse_args(argv[1:])
    args.tasks = get_tasks(args.task_names)
    load_config_json(args)
    return args

def main(argv=sys.argv):
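    """Entry point: seed all RNGs, restore the checkpoint and its vocabulary field,
    rebuild the model (remapping legacy CoVe weight names), then prepare the
    evaluation data and run evaluation.
    """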
    args = get_args(argv)
    logger.info(f'Arguments:\n{pformat(vars(args))}')

    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    logger.info(f'Loading from {args.best_checkpoint}')
    if torch.cuda.is_available():
        save_dict = torch.load(args.best_checkpoint)
    else:
        save_dict = torch.load(args.best_checkpoint, map_location='cpu')

    field = save_dict['field']
    logger.info('Initializing Model')
    Model = getattr(models, args.model)
    model = Model(field, args)
    model_dict = save_dict['model_state_dict']
    backwards_compatible_cove_dict = {}
    for k, v in model_dict.items():
        if 'cove.rnn.' in k:
            k = k.replace('cove.rnn.', 'cove.rnn1.')
        backwards_compatible_cove_dict[k] = v
    model_dict = backwards_compatible_cove_dict
    model.load_state_dict(model_dict)

    field, splits = prepare_data(args, field)
    model.set_embeddings(field.vocab.vectors)
    run(args, field, splits, model)

if __name__ == '__main__':
    main()
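
# A minimal sketch of how this script might be invoked (paths and task names are
# illustrative, and the module path assumes genienlp is importable as a package):
#
#   python -m genienlp.decanlp.predict \
#       --path ./models/my_model \
#       --evaluate valid \
#       --tasks squad \
#       --data ./decaNLP/.data/ \
#       --eval_dir ./eval_results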