This commit is contained in:
Bryan Marcus McCann 2018-08-25 00:53:01 +00:00
parent 1c5b5f3bdf
commit f762fd83c5
2 changed files with 68 additions and 57 deletions

View File

@ -195,7 +195,9 @@ class SQuAD(CQA, data.Dataset):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json']
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',]
name = 'squad'
dirname = ''
@ -217,8 +219,18 @@ class SQuAD(CQA, data.Dataset):
qas = paragraph['qas']
for qa in qas:
question = ' '.join(qa['question'].split())
answer = qa['answers'][0]['text']
squad_id = len(all_answers)
context_question = get_context_question(context, question)
if len(qa['answers']) == 0:
answer = 'unanswerable'
all_answers.append(['unanswerable'])
context = ' '.join(context.split())
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
ex.context_spans = [-1, -1]
ex.answer_start = -1
ex.answer_end = -1
else:
answer = qa['answers'][0]['text']
all_answers.append([a['text'] for a in qa['answers']])
#print('original: ', answer)
answer_start = qa['answers'][0]['answer_start']
@ -229,7 +241,6 @@ class SQuAD(CQA, data.Dataset):
END = ' endanswer'
tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
context_question = get_context_question(context, question)
ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
tokenized_answer = ex.answer
@ -299,7 +310,7 @@ class SQuAD(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data',
def splits(cls, fields, root='.data', description='squad1.1',
train='train', validation='dev', test=None, **kwargs):
"""Create dataset objects for splits of the SQuAD dataset.
Arguments:
@ -313,7 +324,7 @@ class SQuAD(CQA, data.Dataset):
assert test is None
path = cls.download(root)
extension = 'v1.1.json'
extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
train = '-'.join([train, extension]) if train is not None else None
validation = '-'.join([validation, extension]) if validation is not None else None
@ -1104,7 +1115,7 @@ class ZeroShotRE(CQA, data.Dataset):
question = question.replace('XXX', subject)
ex = {'context': context,
'question': question,
'answer': answer if len(answer) > 0 else 'Unanswerable'}
'answer': answer if len(answer) > 0 else 'unanswerable'}
split_file.write(json.dumps(ex)+'\n')

View File

@ -123,7 +123,7 @@ def get_splits(args, task, FIELD, **kwargs):
fields=FIELD, root=args.data, **kwargs)
elif 'squad' in task:
split = torchtext.datasets.generic.SQuAD.splits(
fields=FIELD, root=args.data, **kwargs)
fields=FIELD, root=args.data, description=task, **kwargs)
elif task == 'wikisql':
split = torchtext.datasets.generic.WikiSQL.splits(
fields=FIELD, root=args.data, **kwargs)