diff --git a/predict.py b/predict.py
index d2a82bc9..0d900f66 100644
--- a/predict.py
+++ b/predict.py
@@ -64,6 +64,8 @@ def to_iter(data, bs, device):
 def run(args, field, val_sets, model):
     device = set_seed(args)
     print(f'Preparing iterators')
+    if len(args.val_batch_size) == 1 and len(val_sets) > 1:
+        args.val_batch_size *= len(val_sets)
     iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)]

     def mult(ps):
@@ -82,10 +84,11 @@ def run(args, field, val_sets, model):
     model.eval()
     with torch.no_grad():
         for task, it in iters:
+            print(task)
             prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.txt')
             answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.gold.txt')
             results_file_name = answer_file_name.replace('gold', 'results')
-            if 'sql' in task:
+            if 'sql' in task or 'squad' in task:
                 ids_file_name = answer_file_name.replace('gold', 'ids')
             if os.path.exists(prediction_file_name):
                 print('** ', prediction_file_name, ' already exists -- this is where predictions are stored **')
@@ -104,14 +107,15 @@ def run(args, field, val_sets, model):
             if not os.path.exists(prediction_file_name) or args.overwrite_predictions:
                 with open(prediction_file_name, 'w') as prediction_file:
                     predictions = []
-                    wikisql_ids = []
+                    ids = []
                     for batch_idx, batch in enumerate(it):
                         _, p = model(batch)
                         p = field.reverse(p)
                         for i, pp in enumerate(p):
                             if 'sql' in task:
-                                wikisql_id = int(batch.wikisql_id[i])
-                                wikisql_ids.append(wikisql_id)
+                                ids.append(int(batch.wikisql_id[i]))
+                            if 'squad' in task:
+                                ids.append(it.dataset.q_ids[int(batch.squad_id[i])])
                             prediction_file.write(pp + '\n')
                             predictions.append(pp)
             else:
@@ -120,9 +124,14 @@

             if 'sql' in task:
                 with open(ids_file_name, 'w') as id_file:
-                    for i in wikisql_ids:
+                    for i in ids:
                         id_file.write(json.dumps(i) + '\n')
-
+
+            if 'squad' in task:
+                with open(ids_file_name, 'w') as id_file:
+                    for i in ids:
+                        id_file.write(i + '\n')
+
             def from_all_answers(an):
                 return [it.dataset.all_answers[sid] for sid in an.tolist()]

@@ -165,7 +174,7 @@ def get_args():
     parser = ArgumentParser()
     parser.add_argument('--path', required=True)
    parser.add_argument('--evaluate', type=str, required=True)
-    parser.add_argument('--tasks', default=['wikisql', 'woz.en', 'cnn_dailymail', 'iwslt.en.de', 'zre', 'srl', 'squad', 'sst', 'multinli.in.out', 'schema'], nargs='+')
+    parser.add_argument('--tasks', default=['squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], nargs='+')
     parser.add_argument('--gpus', default=[0], nargs='+', type=int, help='a list of gpus that can be used (multi-gpu currently WIP)')
     parser.add_argument('--seed', default=123, type=int, help='Random seed.')
     parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.')
@@ -180,7 +189,7 @@ def get_args():

     with open(os.path.join(args.path, 'config.json')) as config_file:
         config = json.load(config_file)

-    retrieve = ['model', 'val_batch_size',
+    retrieve = ['model', 'transformer_layers', 'rnn_layers', 'transformer_hidden',
                 'dimension', 'load', 'max_val_context_length', 'val_batch_size',
                 'transformer_heads', 'max_output_length', 'max_generative_vocab',
diff --git a/text/torchtext/datasets/generic.py b/text/torchtext/datasets/generic.py
index 250cc7ba..caee9e10 100644
--- a/text/torchtext/datasets/generic.py
+++ b/text/torchtext/datasets/generic.py
@@ -211,10 +211,10 @@ class SQuAD(CQA, data.Dataset):

         fields = [(x, field) for x in self.fields]
         cache_name = os.path.join(os.path.dirname(path), '.cache', os.path.basename(path), str(subsample))
-        examples, all_answers = [], []
+        examples, all_answers, q_ids = [], [], []
         if os.path.exists(cache_name):
             print(f'Loading cached data from {cache_name}')
-            examples, all_answers = torch.load(cache_name)
+            examples, all_answers, q_ids = torch.load(cache_name)
         else:
             with open(os.path.expanduser(path)) as f:
                 squad = json.load(f)['data']
@@ -226,6 +226,7 @@ class SQuAD(CQA, data.Dataset):
                 qas = paragraph['qas']
                 for qa in qas:
                     question = ' '.join(qa['question'].split())
+                    q_ids.append(qa['id'])
                     squad_id = len(all_answers)
                     context_question = get_context_question(context, question)
                     if len(qa['answers']) == 0:
@@ -303,7 +304,7 @@ class SQuAD(CQA, data.Dataset):

             os.makedirs(os.path.dirname(cache_name), exist_ok=True)
             print(f'Caching data to {cache_name}')
-            torch.save((examples, all_answers), cache_name)
+            torch.save((examples, all_answers, q_ids), cache_name)

         FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
@@ -315,6 +316,7 @@ class SQuAD(CQA, data.Dataset):

         super(SQuAD, self).__init__(examples, fields, **kwargs)
         self.all_answers = all_answers
+        self.q_ids = q_ids

     @classmethod
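
Note: with this patch, run() writes SQuAD predictions and their question ids line-for-line to squad.txt and squad.ids.txt under os.path.splitext(args.best_checkpoint)[0]/args.evaluate/. A minimal sketch (not part of the patch; out_dir and the output file name are hypothetical) of joining the two files into the {question_id: answer} JSON that the official SQuAD evaluate-v1.1.py script consumes:

    import json
    import os

    # Hypothetical directory; run() derives it from args.best_checkpoint and args.evaluate.
    out_dir = 'model/validation'

    # squad.txt and squad.ids.txt pair line-for-line: the i-th id labels the i-th prediction.
    with open(os.path.join(out_dir, 'squad.txt')) as pred_file, \
            open(os.path.join(out_dir, 'squad.ids.txt')) as id_file:
        predictions = [line.rstrip('\n') for line in pred_file]
        q_ids = [line.rstrip('\n') for line in id_file]

    # evaluate-v1.1.py expects a JSON object mapping question id -> answer string.
    with open(os.path.join(out_dir, 'squad.predictions.json'), 'w') as out:
        json.dump(dict(zip(q_ids, predictions)), out)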