This commit is contained in:
Bryan Marcus McCann 2018-08-25 00:53:01 +00:00
parent 1c5b5f3bdf
commit f762fd83c5
2 changed files with 68 additions and 57 deletions

View File

@ -195,7 +195,9 @@ class SQuAD(CQA, data.Dataset):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json']
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',]
name = 'squad'
dirname = ''
@ -217,63 +219,72 @@ class SQuAD(CQA, data.Dataset):
qas = paragraph['qas']
for qa in qas:
question = ' '.join(qa['question'].split())
answer = qa['answers'][0]['text']
squad_id = len(all_answers)
all_answers.append([a['text'] for a in qa['answers']])
#print('original: ', answer)
answer_start = qa['answers'][0]['answer_start']
answer_end = answer_start + len(answer)
context_before_answer = context[:answer_start]
context_after_answer = context[answer_end:]
BEGIN = 'beginanswer '
END = ' endanswer'
tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
context_question = get_context_question(context, question)
ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
if len(qa['answers']) == 0:
answer = 'unanswerable'
all_answers.append(['unanswerable'])
context = ' '.join(context.split())
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
ex.context_spans = [-1, -1]
ex.answer_start = -1
ex.answer_end = -1
else:
answer = qa['answers'][0]['text']
all_answers.append([a['text'] for a in qa['answers']])
#print('original: ', answer)
answer_start = qa['answers'][0]['answer_start']
answer_end = answer_start + len(answer)
context_before_answer = context[:answer_start]
context_after_answer = context[answer_end:]
BEGIN = 'beginanswer '
END = ' endanswer'
tokenized_answer = ex.answer
#print('tokenized: ', tokenized_answer)
for xi, x in enumerate(ex.context):
if BEGIN in x:
answer_start = xi + 1
ex.context[xi] = x.replace(BEGIN, '')
if END in x:
answer_end = xi
ex.context[xi] = x.replace(END, '')
new_context = []
original_answer_start = answer_start
original_answer_end = answer_end
indexed_with_spaces = ex.context[answer_start:answer_end]
if len(indexed_with_spaces) != len(tokenized_answer):
import pdb; pdb.set_trace()
tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
# remove spaces
for xi, x in enumerate(ex.context):
if len(x.strip()) == 0:
if xi <= original_answer_start:
answer_start -= 1
if xi < original_answer_end:
answer_end -= 1
else:
new_context.append(x)
ex.context = new_context
ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
if len(ex.context[answer_start:answer_end]) != len(ex.answer):
import pdb; pdb.set_trace()
ex.context_spans = list(range(answer_start, answer_end))
indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
if len(indexed_answer) != len(ex.answer):
import pdb; pdb.set_trace()
if field.eos_token is not None:
ex.context_spans += [len(ex.context)]
for context_idx, answer_word in zip(ex.context_spans, ex.answer):
if context_idx == len(ex.context):
continue
if ex.context[context_idx] != answer_word:
tokenized_answer = ex.answer
#print('tokenized: ', tokenized_answer)
for xi, x in enumerate(ex.context):
if BEGIN in x:
answer_start = xi + 1
ex.context[xi] = x.replace(BEGIN, '')
if END in x:
answer_end = xi
ex.context[xi] = x.replace(END, '')
new_context = []
original_answer_start = answer_start
original_answer_end = answer_end
indexed_with_spaces = ex.context[answer_start:answer_end]
if len(indexed_with_spaces) != len(tokenized_answer):
import pdb; pdb.set_trace()
ex.answer_start = ex.context_spans[0]
ex.answer_end = ex.context_spans[-1]
# remove spaces
for xi, x in enumerate(ex.context):
if len(x.strip()) == 0:
if xi <= original_answer_start:
answer_start -= 1
if xi < original_answer_end:
answer_end -= 1
else:
new_context.append(x)
ex.context = new_context
ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
if len(ex.context[answer_start:answer_end]) != len(ex.answer):
import pdb; pdb.set_trace()
ex.context_spans = list(range(answer_start, answer_end))
indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
if len(indexed_answer) != len(ex.answer):
import pdb; pdb.set_trace()
if field.eos_token is not None:
ex.context_spans += [len(ex.context)]
for context_idx, answer_word in zip(ex.context_spans, ex.answer):
if context_idx == len(ex.context):
continue
if ex.context[context_idx] != answer_word:
import pdb; pdb.set_trace()
ex.answer_start = ex.context_spans[0]
ex.answer_end = ex.context_spans[-1]
ex.squad_id = squad_id
examples.append(ex)
if subsample is not None and len(examples) > subsample:
@ -299,7 +310,7 @@ class SQuAD(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data',
def splits(cls, fields, root='.data', description='squad1.1',
train='train', validation='dev', test=None, **kwargs):
"""Create dataset objects for splits of the SQuAD dataset.
Arguments:
@ -313,7 +324,7 @@ class SQuAD(CQA, data.Dataset):
assert test is None
path = cls.download(root)
extension = 'v1.1.json'
extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
train = '-'.join([train, extension]) if train is not None else None
validation = '-'.join([validation, extension]) if validation is not None else None
@ -1104,7 +1115,7 @@ class ZeroShotRE(CQA, data.Dataset):
question = question.replace('XXX', subject)
ex = {'context': context,
'question': question,
'answer': answer if len(answer) > 0 else 'Unanswerable'}
'answer': answer if len(answer) > 0 else 'unanswerable'}
split_file.write(json.dumps(ex)+'\n')

View File

@ -123,7 +123,7 @@ def get_splits(args, task, FIELD, **kwargs):
fields=FIELD, root=args.data, **kwargs)
elif 'squad' in task:
split = torchtext.datasets.generic.SQuAD.splits(
fields=FIELD, root=args.data, **kwargs)
fields=FIELD, root=args.data, description=task, **kwargs)
elif task == 'wikisql':
split = torchtext.datasets.generic.WikiSQL.splits(
fields=FIELD, root=args.data, **kwargs)