From f762fd83c58acac3e9f540f509a758c0e2961278 Mon Sep 17 00:00:00 2001
From: Bryan Marcus McCann
Date: Sat, 25 Aug 2018 00:53:01 +0000
Subject: [PATCH] squad2.0

---
 text/torchtext/datasets/generic.py | 123 ++++++++++++++++-------
 util.py                            |   2 +-
 2 files changed, 68 insertions(+), 57 deletions(-)

diff --git a/text/torchtext/datasets/generic.py b/text/torchtext/datasets/generic.py
index 225447e7..d9b1227e 100644
--- a/text/torchtext/datasets/generic.py
+++ b/text/torchtext/datasets/generic.py
@@ -195,7 +195,9 @@ class SQuAD(CQA, data.Dataset):
         return data.interleave_keys(len(ex.context), len(ex.answer))
 
     urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
-            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json']
+            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
+            'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
+            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',]
     name = 'squad'
     dirname = ''
 
@@ -217,63 +219,72 @@ class SQuAD(CQA, data.Dataset):
                     qas = paragraph['qas']
                     for qa in qas:
                         question = ' '.join(qa['question'].split())
-                        answer = qa['answers'][0]['text']
                         squad_id = len(all_answers)
-                        all_answers.append([a['text'] for a in qa['answers']])
-                        #print('original: ', answer)
-                        answer_start = qa['answers'][0]['answer_start']
-                        answer_end = answer_start + len(answer)
-                        context_before_answer = context[:answer_start]
-                        context_after_answer = context[answer_end:]
-                        BEGIN = 'beginanswer '
-                        END = ' endanswer'
-
-                        tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
                         context_question = get_context_question(context, question)
-                        ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
+                        if len(qa['answers']) == 0:
+                            answer = 'unanswerable'
+                            all_answers.append(['unanswerable'])
+                            context = ' '.join(context.split())
+                            ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
+                            ex.context_spans = [-1, -1]
+                            ex.answer_start = -1
+                            ex.answer_end = -1
+                        else:
+                            answer = qa['answers'][0]['text']
+                            all_answers.append([a['text'] for a in qa['answers']])
+                            #print('original: ', answer)
+                            answer_start = qa['answers'][0]['answer_start']
+                            answer_end = answer_start + len(answer)
+                            context_before_answer = context[:answer_start]
+                            context_after_answer = context[answer_end:]
+                            BEGIN = 'beginanswer '
+                            END = ' endanswer'
 
-                        tokenized_answer = ex.answer
-                        #print('tokenized: ', tokenized_answer)
-                        for xi, x in enumerate(ex.context):
-                            if BEGIN in x:
-                                answer_start = xi + 1
-                                ex.context[xi] = x.replace(BEGIN, '')
-                            if END in x:
-                                answer_end = xi
-                                ex.context[xi] = x.replace(END, '')
-                        new_context = []
-                        original_answer_start = answer_start
-                        original_answer_end = answer_end
-                        indexed_with_spaces = ex.context[answer_start:answer_end]
-                        if len(indexed_with_spaces) != len(tokenized_answer):
-                            import pdb; pdb.set_trace()
+                            tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
+                            ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
 
-                        # remove spaces
-                        for xi, x in enumerate(ex.context):
-                            if len(x.strip()) == 0:
-                                if xi <= original_answer_start:
-                                    answer_start -= 1
-                                if xi < original_answer_end:
-                                    answer_end -= 1
-                            else:
-                                new_context.append(x)
-                        ex.context = new_context
-                        ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
-                        if len(ex.context[answer_start:answer_end]) != len(ex.answer):
-                            import pdb; pdb.set_trace()
-                        ex.context_spans = list(range(answer_start, answer_end))
-                        indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
-                        if len(indexed_answer) != len(ex.answer):
-                            import pdb; pdb.set_trace()
-                        if field.eos_token is not None:
-                            ex.context_spans += [len(ex.context)]
-                        for context_idx, answer_word in zip(ex.context_spans, ex.answer):
-                            if context_idx == len(ex.context):
-                                continue
-                            if ex.context[context_idx] != answer_word:
-                                import pdb; pdb.set_trace()
-                        ex.answer_start = ex.context_spans[0]
-                        ex.answer_end = ex.context_spans[-1]
+                            tokenized_answer = ex.answer
+                            #print('tokenized: ', tokenized_answer)
+                            for xi, x in enumerate(ex.context):
+                                if BEGIN in x:
+                                    answer_start = xi + 1
+                                    ex.context[xi] = x.replace(BEGIN, '')
+                                if END in x:
+                                    answer_end = xi
+                                    ex.context[xi] = x.replace(END, '')
+                            new_context = []
+                            original_answer_start = answer_start
+                            original_answer_end = answer_end
+                            indexed_with_spaces = ex.context[answer_start:answer_end]
+                            if len(indexed_with_spaces) != len(tokenized_answer):
+                                import pdb; pdb.set_trace()
+
+                            # remove spaces
+                            for xi, x in enumerate(ex.context):
+                                if len(x.strip()) == 0:
+                                    if xi <= original_answer_start:
+                                        answer_start -= 1
+                                    if xi < original_answer_end:
+                                        answer_end -= 1
+                                else:
+                                    new_context.append(x)
+                            ex.context = new_context
+                            ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
+                            if len(ex.context[answer_start:answer_end]) != len(ex.answer):
+                                import pdb; pdb.set_trace()
+                            ex.context_spans = list(range(answer_start, answer_end))
+                            indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
+                            if len(indexed_answer) != len(ex.answer):
+                                import pdb; pdb.set_trace()
+                            if field.eos_token is not None:
+                                ex.context_spans += [len(ex.context)]
+                            for context_idx, answer_word in zip(ex.context_spans, ex.answer):
+                                if context_idx == len(ex.context):
+                                    continue
+                                if ex.context[context_idx] != answer_word:
+                                    import pdb; pdb.set_trace()
+                            ex.answer_start = ex.context_spans[0]
+                            ex.answer_end = ex.context_spans[-1]
                         ex.squad_id = squad_id
                         examples.append(ex)
                         if subsample is not None and len(examples) > subsample:
@@ -299,7 +310,7 @@ class SQuAD(CQA, data.Dataset):
 
     @classmethod
-    def splits(cls, fields, root='.data',
+    def splits(cls, fields, root='.data', description='squad1.1',
                train='train', validation='dev', test=None, **kwargs):
         """Create dataset objects for splits of the SQuAD dataset.
         Arguments:
@@ -313,7 +324,7 @@ class SQuAD(CQA, data.Dataset):
         assert test is None
         path = cls.download(root)
 
-        extension = 'v1.1.json'
+        extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
         train = '-'.join([train, extension]) if train is not None else None
         validation = '-'.join([validation, extension]) if validation is not None else None
@@ -1104,7 +1115,7 @@ class ZeroShotRE(CQA, data.Dataset):
                 question = question.replace('XXX', subject)
 
                 ex = {'context': context,
                       'question': question,
-                      'answer': answer if len(answer) > 0 else 'Unanswerable'}
+                      'answer': answer if len(answer) > 0 else 'unanswerable'}
 
                 split_file.write(json.dumps(ex)+'\n')
diff --git a/util.py b/util.py
index 695b9557..2db12ed5 100644
--- a/util.py
+++ b/util.py
@@ -123,7 +123,7 @@ def get_splits(args, task, FIELD, **kwargs):
             fields=FIELD, root=args.data, **kwargs)
     elif 'squad' in task:
         split = torchtext.datasets.generic.SQuAD.splits(
-            fields=FIELD, root=args.data, **kwargs)
+            fields=FIELD, root=args.data, description=task, **kwargs)
     elif task == 'wikisql':
         split = torchtext.datasets.generic.WikiSQL.splits(
             fields=FIELD, root=args.data, **kwargs)
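
Note on the main behavioral change above: SQuAD 2.0 keeps questions whose `qa['answers']` list is empty, so the loader can no longer assume `qa['answers'][0]` exists. The new branch emits the literal target string 'unanswerable' with sentinel span indices of -1 instead of indexing into the context. A minimal standalone sketch of that branching over the raw JSON (no torchtext; `iter_squad_examples` is an illustrative name, not a function in this repo):

    import json

    def iter_squad_examples(path):
        with open(path) as f:
            data = json.load(f)['data']
        for document in data:
            for paragraph in document['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    question = ' '.join(qa['question'].split())
                    if len(qa['answers']) == 0:
                        # v2.0 unanswerable question: sentinel answer and spans,
                        # mirroring the new branch in generic.py
                        yield ' '.join(context.split()), question, 'unanswerable', (-1, -1)
                    else:
                        answer = qa['answers'][0]['text']
                        start = qa['answers'][0]['answer_start']
                        yield context, question, answer, (start, start + len(answer))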
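Version selection is plumbed through the new `description` keyword: `util.py` passes the task string (e.g. 'squad2.0') and `splits()` picks the matching data files. A sketch of just that filename logic (`resolve_split_files` is an illustrative helper; the patched code inlines this in `splits()`):

    def resolve_split_files(description='squad1.1', train='train', validation='dev'):
        # Any description containing '2.0' selects the SQuAD 2.0 files;
        # everything else falls back to v1.1, preserving old behavior.
        extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
        train_file = '-'.join([train, extension]) if train is not None else None
        validation_file = '-'.join([validation, extension]) if validation is not None else None
        return train_file, validation_file

    print(resolve_split_files('squad2.0'))  # ('train-v2.0.json', 'dev-v2.0.json')
    print(resolve_split_files('squad'))     # ('train-v1.1.json', 'dev-v1.1.json')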
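The answerable branch is pre-existing logic re-indented under `else:`: it maps character-level answer offsets to token indices by splicing sentinel words around the answer before tokenization and locating them afterwards. A simplified sketch of that trick, assuming whitespace tokenization and an answer aligned to whitespace boundaries (the real loader also handles sentinels glued onto neighboring tokens and strips whitespace-only tokens):

    BEGIN, END = 'beginanswer', 'endanswer'

    def char_span_to_token_span(context, answer_start, answer_end):
        # Splice sentinel words around the answer, tokenize, then find them.
        tagged = (context[:answer_start] + BEGIN + ' '
                  + context[answer_start:answer_end] + ' ' + END
                  + context[answer_end:])
        tokens = tagged.split()
        i, j = tokens.index(BEGIN), tokens.index(END)
        clean = [t for t in tokens if t not in (BEGIN, END)]
        return clean, i, j - 1  # answer tokens are clean[i:j-1]

    tokens, start, end = char_span_to_token_span('the cat sat on the mat', 8, 11)
    assert tokens[start:end] == ['sat']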