squad2.0
This commit is contained in:
parent
1c5b5f3bdf
commit
f762fd83c5
|
@ -195,7 +195,9 @@ class SQuAD(CQA, data.Dataset):
|
|||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
||||
urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
|
||||
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json']
|
||||
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
|
||||
'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
|
||||
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',]
|
||||
name = 'squad'
|
||||
dirname = ''
|
||||
|
||||
|
@ -217,8 +219,18 @@ class SQuAD(CQA, data.Dataset):
|
|||
qas = paragraph['qas']
|
||||
for qa in qas:
|
||||
question = ' '.join(qa['question'].split())
|
||||
answer = qa['answers'][0]['text']
|
||||
squad_id = len(all_answers)
|
||||
context_question = get_context_question(context, question)
|
||||
if len(qa['answers']) == 0:
|
||||
answer = 'unanswerable'
|
||||
all_answers.append(['unanswerable'])
|
||||
context = ' '.join(context.split())
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex.context_spans = [-1, -1]
|
||||
ex.answer_start = -1
|
||||
ex.answer_end = -1
|
||||
else:
|
||||
answer = qa['answers'][0]['text']
|
||||
all_answers.append([a['text'] for a in qa['answers']])
|
||||
#print('original: ', answer)
|
||||
answer_start = qa['answers'][0]['answer_start']
|
||||
|
@ -229,7 +241,6 @@ class SQuAD(CQA, data.Dataset):
|
|||
END = ' endanswer'
|
||||
|
||||
tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
|
||||
tokenized_answer = ex.answer
|
||||
|
@ -299,7 +310,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
def splits(cls, fields, root='.data', description='squad1.1',
|
||||
train='train', validation='dev', test=None, **kwargs):
|
||||
"""Create dataset objects for splits of the SQuAD dataset.
|
||||
Arguments:
|
||||
|
@ -313,7 +324,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
assert test is None
|
||||
path = cls.download(root)
|
||||
|
||||
extension = 'v1.1.json'
|
||||
extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
|
||||
train = '-'.join([train, extension]) if train is not None else None
|
||||
validation = '-'.join([validation, extension]) if validation is not None else None
|
||||
|
||||
|
@ -1104,7 +1115,7 @@ class ZeroShotRE(CQA, data.Dataset):
|
|||
question = question.replace('XXX', subject)
|
||||
ex = {'context': context,
|
||||
'question': question,
|
||||
'answer': answer if len(answer) > 0 else 'Unanswerable'}
|
||||
'answer': answer if len(answer) > 0 else 'unanswerable'}
|
||||
split_file.write(json.dumps(ex)+'\n')
|
||||
|
||||
|
||||
|
|
2
util.py
2
util.py
|
@ -123,7 +123,7 @@ def get_splits(args, task, FIELD, **kwargs):
|
|||
fields=FIELD, root=args.data, **kwargs)
|
||||
elif 'squad' in task:
|
||||
split = torchtext.datasets.generic.SQuAD.splits(
|
||||
fields=FIELD, root=args.data, **kwargs)
|
||||
fields=FIELD, root=args.data, description=task, **kwargs)
|
||||
elif task == 'wikisql':
|
||||
split = torchtext.datasets.generic.WikiSQL.splits(
|
||||
fields=FIELD, root=args.data, **kwargs)
|
||||
|
|
Loading…
Reference in New Issue