squad2.0

commit f762fd83c5 (parent 1c5b5f3bdf)

generic.py
@@ -195,7 +195,9 @@ class SQuAD(CQA, data.Dataset):
         return data.interleave_keys(len(ex.context), len(ex.answer))

     urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
-            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json']
+            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
+            'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
+            'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',]
     name = 'squad'
     dirname = ''
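The two new downloads are the official SQuAD 2.0 release, which differs from v1.1 chiefly by adding questions that have no answer in the passage. A quick sketch for inspecting how common those are, reusing the same empty-answers test the loader applies below (the local path is an assumption about where download() places the files):

    import json

    # Count unanswerable questions in a downloaded SQuAD 2.0 dev file, using
    # the same len(qa['answers']) == 0 check as the loader in this commit.
    with open('.data/squad/dev-v2.0.json') as f:
        squad = json.load(f)

    unanswerable = sum(
        1
        for article in squad['data']
        for paragraph in article['paragraphs']
        for qa in paragraph['qas']
        if len(qa['answers']) == 0)
    print(unanswerable)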
@@ -217,63 +219,72 @@ class SQuAD(CQA, data.Dataset):
                 qas = paragraph['qas']
                 for qa in qas:
                     question = ' '.join(qa['question'].split())
-                    answer = qa['answers'][0]['text']
-                    all_answers.append([a['text'] for a in qa['answers']])
-                    #print('original: ', answer)
-                    answer_start = qa['answers'][0]['answer_start']
-                    answer_end = answer_start + len(answer)
-                    context_before_answer = context[:answer_start]
-                    context_after_answer = context[answer_end:]
-                    BEGIN = 'beginanswer '
-                    END = ' endanswer'
-
-                    tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
+                    squad_id = len(all_answers)
                     context_question = get_context_question(context, question)
-                    ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
-
-                    tokenized_answer = ex.answer
-                    #print('tokenized: ', tokenized_answer)
-                    for xi, x in enumerate(ex.context):
-                        if BEGIN in x:
-                            answer_start = xi + 1
-                            ex.context[xi] = x.replace(BEGIN, '')
-                        if END in x:
-                            answer_end = xi
-                            ex.context[xi] = x.replace(END, '')
-                    new_context = []
-                    original_answer_start = answer_start
-                    original_answer_end = answer_end
-                    indexed_with_spaces = ex.context[answer_start:answer_end]
-                    if len(indexed_with_spaces) != len(tokenized_answer):
-                        import pdb; pdb.set_trace()
-
-                    # remove spaces
-                    for xi, x in enumerate(ex.context):
-                        if len(x.strip()) == 0:
-                            if xi <= original_answer_start:
-                                answer_start -= 1
-                            if xi < original_answer_end:
-                                answer_end -= 1
-                        else:
-                            new_context.append(x)
-                    ex.context = new_context
-                    ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
-                    if len(ex.context[answer_start:answer_end]) != len(ex.answer):
-                        import pdb; pdb.set_trace()
-                    ex.context_spans = list(range(answer_start, answer_end))
-                    indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
-                    if len(indexed_answer) != len(ex.answer):
-                        import pdb; pdb.set_trace()
-                    if field.eos_token is not None:
-                        ex.context_spans += [len(ex.context)]
-                    for context_idx, answer_word in zip(ex.context_spans, ex.answer):
-                        if context_idx == len(ex.context):
-                            continue
-                        if ex.context[context_idx] != answer_word:
-                            import pdb; pdb.set_trace()
-                    ex.answer_start = ex.context_spans[0]
-                    ex.answer_end = ex.context_spans[-1]
+                    if len(qa['answers']) == 0:
+                        answer = 'unanswerable'
+                        all_answers.append(['unanswerable'])
+                        context = ' '.join(context.split())
+                        ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
+                        ex.context_spans = [-1, -1]
+                        ex.answer_start = -1
+                        ex.answer_end = -1
+                    else:
+                        answer = qa['answers'][0]['text']
+                        all_answers.append([a['text'] for a in qa['answers']])
+                        #print('original: ', answer)
+                        answer_start = qa['answers'][0]['answer_start']
+                        answer_end = answer_start + len(answer)
+                        context_before_answer = context[:answer_start]
+                        context_after_answer = context[answer_end:]
+                        BEGIN = 'beginanswer '
+                        END = ' endanswer'
+
+                        tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
+                        ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
+
+                        tokenized_answer = ex.answer
+                        #print('tokenized: ', tokenized_answer)
+                        for xi, x in enumerate(ex.context):
+                            if BEGIN in x:
+                                answer_start = xi + 1
+                                ex.context[xi] = x.replace(BEGIN, '')
+                            if END in x:
+                                answer_end = xi
+                                ex.context[xi] = x.replace(END, '')
+                        new_context = []
+                        original_answer_start = answer_start
+                        original_answer_end = answer_end
+                        indexed_with_spaces = ex.context[answer_start:answer_end]
+                        if len(indexed_with_spaces) != len(tokenized_answer):
+                            import pdb; pdb.set_trace()
+
+                        # remove spaces
+                        for xi, x in enumerate(ex.context):
+                            if len(x.strip()) == 0:
+                                if xi <= original_answer_start:
+                                    answer_start -= 1
+                                if xi < original_answer_end:
+                                    answer_end -= 1
+                            else:
+                                new_context.append(x)
+                        ex.context = new_context
+                        ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
+                        if len(ex.context[answer_start:answer_end]) != len(ex.answer):
+                            import pdb; pdb.set_trace()
+                        ex.context_spans = list(range(answer_start, answer_end))
+                        indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
+                        if len(indexed_answer) != len(ex.answer):
+                            import pdb; pdb.set_trace()
+                        if field.eos_token is not None:
+                            ex.context_spans += [len(ex.context)]
+                        for context_idx, answer_word in zip(ex.context_spans, ex.answer):
+                            if context_idx == len(ex.context):
+                                continue
+                            if ex.context[context_idx] != answer_word:
+                                import pdb; pdb.set_trace()
+                        ex.answer_start = ex.context_spans[0]
+                        ex.answer_end = ex.context_spans[-1]
+                    ex.squad_id = squad_id
                     examples.append(ex)
                     if subsample is not None and len(examples) > subsample:
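The else branch recovers token-level answer spans with a sentinel trick: splice 'beginanswer ' and ' endanswer' into the raw context at the answer's character offsets, tokenize the tagged string, locate the sentinels, then strip them and shift the indices. A minimal self-contained sketch of the same technique, assuming a plain whitespace tokenizer in place of the Field's tokenizer (the helper name find_answer_span is made up for illustration):

    def find_answer_span(context, answer_start, answer_text):
        # Tag the raw character span; the sentinels travel with the answer
        # through tokenization and mark its token positions.
        answer_end = answer_start + len(answer_text)
        tagged = (context[:answer_start] + 'beginanswer ' + answer_text
                  + ' endanswer' + context[answer_end:])
        tokens = tagged.split()
        start = tokens.index('beginanswer') + 1   # first answer token
        end = tokens.index('endanswer')           # one past the last answer token
        # Drop the sentinels and shift both indices left to compensate for
        # the removed 'beginanswer' token.
        tokens = [t for t in tokens if t not in ('beginanswer', 'endanswer')]
        return tokens, start - 1, end - 1

    tokens, start, end = find_answer_span('The Normans were in Normandy .', 20, 'Normandy')
    assert tokens[start:end] == ['Normandy']

With a real punctuation-aware tokenizer the sentinels can fuse with neighboring tokens, which is why the code above searches with 'if BEGIN in x' and cleans up via replace(), and why the later space-removal pass re-bases the indices a second time.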
@@ -299,7 +310,7 @@ class SQuAD(CQA, data.Dataset):

     @classmethod
-    def splits(cls, fields, root='.data',
+    def splits(cls, fields, root='.data', description='squad1.1',
                train='train', validation='dev', test=None, **kwargs):
         """Create dataset objects for splits of the SQuAD dataset.

         Arguments:
@@ -313,7 +324,7 @@ class SQuAD(CQA, data.Dataset):
         assert test is None
         path = cls.download(root)

-        extension = 'v1.1.json'
+        extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
         train = '-'.join([train, extension]) if train is not None else None
         validation = '-'.join([validation, extension]) if validation is not None else None
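The description string is all that selects between the releases; anything containing '2.0' maps to the v2.0 files. The resolution logic, shown standalone:

    # Filename resolution as the loader now performs it; 'description' comes
    # straight from the caller (util.py forwards the task name).
    for description in ('squad1.1', 'squad2.0'):
        extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
        train = '-'.join(['train', extension])
        validation = '-'.join(['dev', extension])
        print(description, '->', train, validation)
    # squad1.1 -> train-v1.1.json dev-v1.1.json
    # squad2.0 -> train-v2.0.json dev-v2.0.json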
@@ -1104,7 +1115,7 @@ class ZeroShotRE(CQA, data.Dataset):
                 question = question.replace('XXX', subject)
                 ex = {'context': context,
                       'question': question,
-                      'answer': answer if len(answer) > 0 else 'Unanswerable'}
+                      'answer': answer if len(answer) > 0 else 'unanswerable'}
                 split_file.write(json.dumps(ex)+'\n')
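Lowercasing the label makes ZeroShotRE's no-answer string identical to the 'unanswerable' answer the SQuAD 2.0 loader emits, presumably so the two datasets share one no-answer vocabulary item. For illustration only (the field values are placeholders, not real data):

    import json

    # Placeholder record; the point is the shared lowercase no-answer label.
    answer = ''
    ex = {'context': 'some context',
          'question': 'some question',
          'answer': answer if len(answer) > 0 else 'unanswerable'}
    print(json.dumps(ex))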
util.py
@@ -123,7 +123,7 @@ def get_splits(args, task, FIELD, **kwargs):
             fields=FIELD, root=args.data, **kwargs)
     elif 'squad' in task:
         split = torchtext.datasets.generic.SQuAD.splits(
-            fields=FIELD, root=args.data, **kwargs)
+            fields=FIELD, root=args.data, description=task, **kwargs)
     elif task == 'wikisql':
         split = torchtext.datasets.generic.WikiSQL.splits(
             fields=FIELD, root=args.data, **kwargs)
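Because the branch matches any task name containing 'squad', the task string itself becomes the description. A sketch of the routing with a stub standing in for the real SQuAD.splits (the stub is an assumption; only the dispatch mirrors util.py):

    def fake_splits(description='squad1.1', **kwargs):
        # Stub for torchtext.datasets.generic.SQuAD.splits.
        return 'v2.0' if '2.0' in description else 'v1.1'

    def dispatch(task):
        if 'squad' in task:
            return fake_splits(description=task)
        raise ValueError(task)

    assert dispatch('squad2.0') == 'v2.0'
    assert dispatch('squad') == 'v1.1'   # a bare 'squad' task still means v1.1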