diff --git a/decanlp/tasks/generic_dataset.py b/decanlp/tasks/generic_dataset.py index af82277b..7bd18b8d 100644 --- a/decanlp/tasks/generic_dataset.py +++ b/decanlp/tasks/generic_dataset.py @@ -42,18 +42,11 @@ import xml.etree.ElementTree as ET from ..text import data -CONTEXT_SPECIAL = 'Context:' -QUESTION_SPECIAL = 'Question:' - logger = logging.getLogger(__name__) -def get_context_question(context, question): - return CONTEXT_SPECIAL + ' ' + context + ' ' + QUESTION_SPECIAL + ' ' + question - - class CQA(data.Dataset): - fields = ['context', 'question', 'answer', 'context_special', 'question_special', 'context_question'] + fields = ['context', 'question', 'answer'] @staticmethod def sort_key(ex): @@ -87,8 +80,7 @@ class IMDb(CQA): with open(fname, 'r') as f: context = f.readline() answer = labels[label] - context_question = get_context_question(context, question) - examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)) + examples.append(data.Example.fromlist([context, question, answer], fields)) if subsample is not None and len(examples) > subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) @@ -148,8 +140,7 @@ class SST(CQA): parsed = list(csv.reader([line.rstrip('\n')]))[0] context = parsed[-1] answer = labels[int(parsed[0])] - context_question = get_context_question(context, question) - examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)) + examples.append(data.Example.fromlist([context, question, answer], fields)) if subsample is not None and len(examples) > subsample: break @@ -219,8 +210,7 @@ class TranslationDataset(CQA): if src_line != '' and trg_line != '': context = src_line answer = trg_line - context_question = get_context_question(context, question) - examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize)) + examples.append(data.Example.fromlist([context, question, answer], fields, tokenize=tokenize)) if subsample is not None and len(examples) >= subsample: break @@ -385,12 +375,11 @@ class SQuAD(CQA, data.Dataset): question = ' '.join(qa['question'].split()) q_ids.append(qa['id']) squad_id = len(all_answers) - context_question = get_context_question(context, question) if len(qa['answers']) == 0: answer = 'unanswerable' all_answers.append(['unanswerable']) context = ' '.join(context.split()) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) ex.context_spans = [-1, -1] ex.answer_start = -1 ex.answer_end = -1 @@ -405,7 +394,7 @@ class SQuAD(CQA, data.Dataset): END = ' endanswer' tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer - ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([tagged_context, question, answer], fields) tokenized_answer = ex.answer for xi, x in enumerate(ex.context): @@ -542,8 +531,7 @@ class Summarization(CQA, data.Dataset): for line in lines: ex = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break @@ -719,8 +707,7 @@ class WikiSQL(CQA, data.Dataset): else: question = 'What is the translation from English to SQL?' context += f'-- {human_query}' - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, idx], fields) + ex = data.Example.fromlist([context, question, answer, idx], fields) examples.append(ex) all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table}) if subsample is not None and len(examples) > subsample: @@ -820,8 +807,7 @@ class SRL(CQA, data.Dataset): t = ex['type'] aa = ex['all_answers'] context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) ex.squad_id = len(all_answers) all_answers.append(aa) @@ -977,8 +963,7 @@ class WinogradSchema(CQA, data.Dataset): for line in f: ex = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break @@ -1101,9 +1086,8 @@ class WOZ(CQA, data.Dataset): ex = example_dict = json.loads(line) if example_dict['lang'] in description: context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) all_answers.append((ex['lang_dialogue_turn'], answer)) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, woz_id], fields) + ex = data.Example.fromlist([context, question, answer, woz_id], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: @@ -1225,8 +1209,7 @@ class MultiNLI(CQA, data.Dataset): ex = example_dict = json.loads(line) if example_dict['subtask'] in description: context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break @@ -1310,8 +1293,7 @@ class ZeroShotRE(CQA, data.Dataset): for line in f: ex = example_dict = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: @@ -1448,8 +1430,7 @@ class OntoNotesNER(CQA, data.Dataset): if a != 'None' or nones: ex = example_dict context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: @@ -1636,8 +1617,7 @@ class SNLI(CQA, data.Dataset): example_dict = json.loads(line) ex = example_dict context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: @@ -1709,8 +1689,7 @@ class JSON(CQA, data.Dataset): for line in lines: ex = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] - context_question = get_context_question(context, question) - ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) + ex = data.Example.fromlist([context, question, answer], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break diff --git a/decanlp/util.py b/decanlp/util.py index d6e06f69..14a3323b 100644 --- a/decanlp/util.py +++ b/decanlp/util.py @@ -92,7 +92,6 @@ def preprocess_examples(args, tasks, splits, field, logger=None, train=True): for ex in s.examples[:10]: logger.info('Context: ' + ' '.join([token.strip() for token in ex.context])) logger.info('Question: ' + ' '.join([token.strip() for token in ex.question])) - logger.info(' '.join([token.strip() for token in ex.context_question])) logger.info('Answer: ' + ' '.join([token.strip() for token in ex.answer]))