parent
652657acb5
commit
6059c11ac7
|
@ -42,18 +42,11 @@ import xml.etree.ElementTree as ET
|
|||
|
||||
from ..text import data
|
||||
|
||||
CONTEXT_SPECIAL = 'Context:'
|
||||
QUESTION_SPECIAL = 'Question:'
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_context_question(context, question):
|
||||
return CONTEXT_SPECIAL + ' ' + context + ' ' + QUESTION_SPECIAL + ' ' + question
|
||||
|
||||
|
||||
class CQA(data.Dataset):
|
||||
|
||||
fields = ['context', 'question', 'answer', 'context_special', 'question_special', 'context_question']
|
||||
fields = ['context', 'question', 'answer']
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
|
@ -87,8 +80,7 @@ class IMDb(CQA):
|
|||
with open(fname, 'r') as f:
|
||||
context = f.readline()
|
||||
answer = labels[label]
|
||||
context_question = get_context_question(context, question)
|
||||
examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields))
|
||||
examples.append(data.Example.fromlist([context, question, answer], fields))
|
||||
if subsample is not None and len(examples) > subsample:
|
||||
break
|
||||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
|
@ -148,8 +140,7 @@ class SST(CQA):
|
|||
parsed = list(csv.reader([line.rstrip('\n')]))[0]
|
||||
context = parsed[-1]
|
||||
answer = labels[int(parsed[0])]
|
||||
context_question = get_context_question(context, question)
|
||||
examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields))
|
||||
examples.append(data.Example.fromlist([context, question, answer], fields))
|
||||
|
||||
if subsample is not None and len(examples) > subsample:
|
||||
break
|
||||
|
@ -219,8 +210,7 @@ class TranslationDataset(CQA):
|
|||
if src_line != '' and trg_line != '':
|
||||
context = src_line
|
||||
answer = trg_line
|
||||
context_question = get_context_question(context, question)
|
||||
examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize))
|
||||
examples.append(data.Example.fromlist([context, question, answer], fields, tokenize=tokenize))
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
||||
|
@ -385,12 +375,11 @@ class SQuAD(CQA, data.Dataset):
|
|||
question = ' '.join(qa['question'].split())
|
||||
q_ids.append(qa['id'])
|
||||
squad_id = len(all_answers)
|
||||
context_question = get_context_question(context, question)
|
||||
if len(qa['answers']) == 0:
|
||||
answer = 'unanswerable'
|
||||
all_answers.append(['unanswerable'])
|
||||
context = ' '.join(context.split())
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
ex.context_spans = [-1, -1]
|
||||
ex.answer_start = -1
|
||||
ex.answer_end = -1
|
||||
|
@ -405,7 +394,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
END = ' endanswer'
|
||||
|
||||
tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
|
||||
ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([tagged_context, question, answer], fields)
|
||||
|
||||
tokenized_answer = ex.answer
|
||||
for xi, x in enumerate(ex.context):
|
||||
|
@ -542,8 +531,7 @@ class Summarization(CQA, data.Dataset):
|
|||
for line in lines:
|
||||
ex = json.loads(line)
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -719,8 +707,7 @@ class WikiSQL(CQA, data.Dataset):
|
|||
else:
|
||||
question = 'What is the translation from English to SQL?'
|
||||
context += f'-- {human_query}'
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, idx], fields)
|
||||
ex = data.Example.fromlist([context, question, answer, idx], fields)
|
||||
examples.append(ex)
|
||||
all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table})
|
||||
if subsample is not None and len(examples) > subsample:
|
||||
|
@ -820,8 +807,7 @@ class SRL(CQA, data.Dataset):
|
|||
t = ex['type']
|
||||
aa = ex['all_answers']
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
ex.squad_id = len(all_answers)
|
||||
all_answers.append(aa)
|
||||
|
@ -977,8 +963,7 @@ class WinogradSchema(CQA, data.Dataset):
|
|||
for line in f:
|
||||
ex = json.loads(line)
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -1101,9 +1086,8 @@ class WOZ(CQA, data.Dataset):
|
|||
ex = example_dict = json.loads(line)
|
||||
if example_dict['lang'] in description:
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
all_answers.append((ex['lang_dialogue_turn'], answer))
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, woz_id], fields)
|
||||
ex = data.Example.fromlist([context, question, answer, woz_id], fields)
|
||||
examples.append(ex)
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
|
@ -1225,8 +1209,7 @@ class MultiNLI(CQA, data.Dataset):
|
|||
ex = example_dict = json.loads(line)
|
||||
if example_dict['subtask'] in description:
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -1310,8 +1293,7 @@ class ZeroShotRE(CQA, data.Dataset):
|
|||
for line in f:
|
||||
ex = example_dict = json.loads(line)
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
|
@ -1448,8 +1430,7 @@ class OntoNotesNER(CQA, data.Dataset):
|
|||
if a != 'None' or nones:
|
||||
ex = example_dict
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
|
@ -1636,8 +1617,7 @@ class SNLI(CQA, data.Dataset):
|
|||
example_dict = json.loads(line)
|
||||
ex = example_dict
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
|
@ -1709,8 +1689,7 @@ class JSON(CQA, data.Dataset):
|
|||
for line in lines:
|
||||
ex = json.loads(line)
|
||||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
context_question = get_context_question(context, question)
|
||||
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
|
||||
ex = data.Example.fromlist([context, question, answer], fields)
|
||||
examples.append(ex)
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
|
|
@ -92,7 +92,6 @@ def preprocess_examples(args, tasks, splits, field, logger=None, train=True):
|
|||
for ex in s.examples[:10]:
|
||||
logger.info('Context: ' + ' '.join([token.strip() for token in ex.context]))
|
||||
logger.info('Question: ' + ' '.join([token.strip() for token in ex.question]))
|
||||
logger.info(' '.join([token.strip() for token in ex.context_question]))
|
||||
logger.info('Answer: ' + ' '.join([token.strip() for token in ex.answer]))
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue