# # Copyright (c) 2018, Salesforce, Inc. # The Board of Trustees of the Leland Stanford Junior University # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. # # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # * Neither the name of the copyright holder nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import os import re import revtok import torch import io import csv import json import glob import hashlib import unicodedata import logging from ..text.torchtext.datasets import imdb from ..text.torchtext.datasets import translation from ..text.torchtext import data CONTEXT_SPECIAL = 'Context:' QUESTION_SPECIAL = 'Question:' logger = logging.getLogger(__name__) def get_context_question(context, question): return CONTEXT_SPECIAL + ' ' + context + ' ' + QUESTION_SPECIAL + ' ' + question class CQA(data.Dataset): fields = ['context', 'question', 'answer', 'context_special', 'question_special', 'context_question'] @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) class IMDb(CQA, imdb.IMDb): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) def __init__(self, path, field, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] examples = [] labels = {'neg': 'negative', 'pos': 'positive'} question = 'Is this review negative or positive?' cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: for label in ['pos', 'neg']: for fname in glob.iglob(os.path.join(path, label, '*.txt')): with open(fname, 'r') as f: context = f.readline() answer = labels[label] context_question = get_context_question(context, question) examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)) if subsample is not None and len(examples) > subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super(imdb.IMDb, self).__init__(examples, fields, **kwargs) @classmethod def splits(cls, fields, root='.data', train='train', validation=None, test='test', **kwargs): assert validation is None path = cls.download(root) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}'), fields, **kwargs) return tuple(d for d in (train_data, test_data, aux_data) if d is not None) class SST(CQA): urls = ['https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/train_binary_sent.csv', 'https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/dev_binary_sent.csv', 'https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/test_binary_sent.csv'] name = 'sst' dirname = '' @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) def __init__(self, path, field, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) examples = [] skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: labels = ['negative', 'positive'] question = 'Is this review ' + labels[0] + ' or ' + labels[1] + '?' with io.open(os.path.expanduser(path), encoding='utf8') as f: next(f) for line in f: parsed = list(csv.reader([line.rstrip('\n')]))[0] context = parsed[-1] answer = labels[int(parsed[0])] context_question = get_context_question(context, question) examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)) if subsample is not None and len(examples) > subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) self.examples = examples super().__init__(examples, fields, **kwargs) @classmethod def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) postfix = f'_binary_sent.csv' aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, f'aux{postfix}'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}{postfix}'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}{postfix}'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}{postfix}'), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class TranslationDataset(translation.TranslationDataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) def __init__(self, path, exts, field, subsample=None, tokenize=None, **kwargs): """Create a TranslationDataset given paths and fields. Arguments: path: Common prefix of paths to the data files for both languages. exts: A tuple containing the extension to path for each language. fields$: fields for handling all columns Remaining keyword arguments: Passed to the constructor of data.Dataset. """ fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: langs = {'.de': 'German', '.en': 'English', '.fr': 'French', '.ar': 'Arabic', '.cs': 'Czech', '.tt': 'ThingTalk', '.fa': 'Farsi'} source, target = langs[exts[0]], langs[exts[1]] src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts) question = f'Translate from {source} to {target}' examples = [] with open(src_path) as src_file, open(trg_path) as trg_file: for src_line, trg_line in zip(src_file, trg_file): src_line, trg_line = src_line.strip(), trg_line.strip() if src_line != '' and trg_line != '': context = src_line answer = trg_line context_question = get_context_question(context, question) examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize)) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super(translation.TranslationDataset, self).__init__(examples, fields, **kwargs) class Multi30k(TranslationDataset, CQA, translation.Multi30k): pass class IWSLT(TranslationDataset, CQA, translation.IWSLT): pass class SQuAD(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json', 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json', 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json', 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',] name = 'squad' dirname = '' def __init__(self, path, field, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) examples, all_answers, q_ids = [], [], [] skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples, all_answers, q_ids = torch.load(cache_name) else: with open(os.path.expanduser(path)) as f: squad = json.load(f)['data'] for document in squad: title = document['title'] paragraphs = document['paragraphs'] for paragraph in paragraphs: context = paragraph['context'] qas = paragraph['qas'] for qa in qas: question = ' '.join(qa['question'].split()) q_ids.append(qa['id']) squad_id = len(all_answers) context_question = get_context_question(context, question) if len(qa['answers']) == 0: answer = 'unanswerable' all_answers.append(['unanswerable']) context = ' '.join(context.split()) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) ex.context_spans = [-1, -1] ex.answer_start = -1 ex.answer_end = -1 else: answer = qa['answers'][0]['text'] all_answers.append([a['text'] for a in qa['answers']]) answer_start = qa['answers'][0]['answer_start'] answer_end = answer_start + len(answer) context_before_answer = context[:answer_start] context_after_answer = context[answer_end:] BEGIN = 'beginanswer ' END = ' endanswer' tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) tokenized_answer = ex.answer for xi, x in enumerate(ex.context): if BEGIN in x: answer_start = xi + 1 ex.context[xi] = x.replace(BEGIN, '') if END in x: answer_end = xi ex.context[xi] = x.replace(END, '') new_context = [] original_answer_start = answer_start original_answer_end = answer_end indexed_with_spaces = ex.context[answer_start:answer_end] if len(indexed_with_spaces) != len(tokenized_answer): import pdb; pdb.set_trace() # remove spaces for xi, x in enumerate(ex.context): if len(x.strip()) == 0: if xi <= original_answer_start: answer_start -= 1 if xi < original_answer_end: answer_end -= 1 else: new_context.append(x) ex.context = new_context ex.answer = [x for x in ex.answer if len(x.strip()) > 0] if len(ex.context[answer_start:answer_end]) != len(ex.answer): import pdb; pdb.set_trace() ex.context_spans = list(range(answer_start, answer_end)) indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1] if len(indexed_answer) != len(ex.answer): import pdb; pdb.set_trace() if field.eos_token is not None: ex.context_spans += [len(ex.context)] for context_idx, answer_word in zip(ex.context_spans, ex.answer): if context_idx == len(ex.context): continue if ex.context[context_idx] != answer_word: import pdb; pdb.set_trace() ex.answer_start = ex.context_spans[0] ex.answer_end = ex.context_spans[-1] ex.squad_id = squad_id examples.append(ex) if subsample is not None and len(examples) > subsample: break if subsample is not None and len(examples) > subsample: break if subsample is not None and len(examples) > subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers, q_ids), cache_name) FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False, lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token) fields.append(('context_spans', FIELD)) fields.append(('answer_start', FIELD)) fields.append(('answer_end', FIELD)) fields.append(('squad_id', FIELD)) super(SQuAD, self).__init__(examples, fields, **kwargs) self.all_answers = all_answers self.q_ids = q_ids @classmethod def splits(cls, fields, root='.data', description='squad1.1', train='train', validation='dev', test=None, **kwargs): """Create dataset objects for splits of the SQuAD dataset. Arguments: root: directory containing SQuAD data field: field for handling all columns train: The prefix of the train data. Default: 'train'. validation: The prefix of the validation data. Default: 'val'. Remaining keyword arguments: Passed to the splits method of Dataset. """ assert test is None path = cls.download(root) extension = 'v2.0.json' if '2.0' in description else 'v1.1.json' aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux = '-'.join(['aux', extension]) aux_data = cls(os.path.join(path, aux), fields, **kwargs) train = '-'.join([train, extension]) if train is not None else None validation = '-'.join([validation, extension]) if validation is not None else None train_data = None if train is None else cls( os.path.join(path, train), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, validation), fields, **kwargs) return tuple(d for d in (train_data, validation_data, aux_data) if d is not None) # https://github.com/abisee/cnn-dailymail/blob/8eace60f306dcbab30d1f1d715e379f07a3782db/make_datafiles.py dm_single_close_quote = u'\u2019' dm_double_close_quote = u'\u201d' END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence def fix_missing_period(line): """Adds a period to a line that is missing a period""" if "@highlight" in line: return line if line=="": return line if line[-1] in END_TOKENS: return line return line + "." class Summarization(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) def __init__(self, path, field, one_answer=True, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) examples = [] skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: with open(os.path.expanduser(path)) as f: lines = f.readlines() for line in lines: ex = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super(Summarization, self).__init__(examples, fields, **kwargs) @classmethod def cache_splits(cls, path): splits = ['training', 'validation', 'test'] for split in splits: missing_stories, collected_stories = 0, 0 split_file_name = os.path.join(path, f'{split}.jsonl') if os.path.exists(split_file_name): continue with open(split_file_name, 'w') as split_file: url_file_name = os.path.join(path, f'{cls.name}_wayback_{split}_urls.txt') with open(url_file_name) as url_file: for url in url_file: story_file_name = os.path.join(path, 'stories', f"{hashlib.sha1(url.strip().encode('utf-8')).hexdigest()}.story") try: story_file = open(story_file_name) except EnvironmentError as e: missing_stories += 1 logger.warning(e) if os.path.exists(split_file_name): os.remove(split_file_name) else: with story_file: article, highlight = [], [] is_highlight = False for line in story_file: line = line.strip() if line == "": continue line = fix_missing_period(line) if line.startswith("@highlight"): is_highlight = True elif "@highlight" in line: raise elif is_highlight: highlight.append(line) else: article.append(line) example = {'context': unicodedata.normalize('NFKC', ' '.join(article)), 'answer': unicodedata.normalize('NFKC', ' '.join(highlight)), 'question': 'What is the summary?'} split_file.write(json.dumps(example)+'\n') collected_stories += 1 if collected_stories % 1000 == 0: logger.debug(example) logger.warning(f'Missing {missing_stories} stories') logger.info(f'Collected {collected_stories} stories') @classmethod def splits(cls, fields, root='.data', train='training', validation='validation', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, 'training.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, 'validation.jsonl'), fields, one_answer=False, **kwargs) test_data = None if test is None else cls( os.path.join(path, 'test.jsonl'), fields, one_answer=False, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class DailyMail(Summarization): name = 'dailymail' dirname = 'dailymail' urls = [('https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs', 'dailymail_stories.tgz'), ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt', 'dailymail/dailymail_wayback_training_urls.txt'), ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt', 'dailymail/dailymail_wayback_validation_urls.txt'), ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt', 'dailymail/dailymail_wayback_test_urls.txt')] class CNN(Summarization): name = 'cnn' dirname = 'cnn' urls = [('https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ', 'cnn_stories.tgz'), ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt', 'cnn/cnn_wayback_training_urls.txt'), ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt', 'cnn/cnn_wayback_validation_urls.txt'), ('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt', 'cnn/cnn_wayback_test_urls.txt')] class Query: #https://github.com/salesforce/WikiSQL/blob/c2ed4f9b22db1cc2721805d53e6e76e07e2ccbdc/lib/query.py#L10 agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG'] cond_ops = ['=', '>', '<', 'OP'] syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS'] def __init__(self, sel_index, agg_index, columns, conditions=tuple()): self.sel_index = sel_index self.agg_index = agg_index self.columns = columns self.conditions = list(conditions) def __repr__(self): rep = 'SELECT {agg} {sel} FROM table'.format( agg=self.agg_ops[self.agg_index], sel= self.columns[self.sel_index] if self.columns is not None else 'col{}'.format(self.sel_index), ) if self.conditions: rep += ' WHERE ' + ' AND '.join(['{} {} {}'.format(self.columns[i], self.cond_ops[o], v) for i, o, v in self.conditions]) return ' '.join(rep.split()) @classmethod def from_dict(cls, d, t): return cls(sel_index=d['sel'], agg_index=d['agg'], columns=t, conditions=d['conds']) class WikiSQL(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2'] name = 'wikisql' dirname = 'data' def __init__(self, path, field, query_as_question=False, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False, lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token) fields.append(('wikisql_id', FIELD)) cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples, all_answers = torch.load(cache_name) else: expanded_path = os.path.expanduser(path) table_path = os.path.splitext(expanded_path) table_path = table_path[0] + '.tables' + table_path[1] with open(table_path) as tables_file: tables = [json.loads(line) for line in tables_file] id_to_tables = {x['id']: x for x in tables} all_answers = [] examples = [] with open(expanded_path) as example_file: for idx, line in enumerate(example_file): entry = json.loads(line) human_query = entry['question'] table = id_to_tables[entry['table_id']] sql = entry['sql'] header = table['header'] answer = repr(Query.from_dict(sql, header)) context = (f'The table has columns {", ".join(table["header"])} ' + f'and key words {", ".join(Query.agg_ops[1:] + Query.cond_ops + Query.syms)}') if query_as_question: question = human_query else: question = 'What is the translation from English to SQL?' context += f'-- {human_query}' context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, idx], fields) examples.append(ex) all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table}) if subsample is not None and len(examples) > subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers), cache_name) super(WikiSQL, self).__init__(examples, fields, **kwargs) self.all_answers = all_answers @classmethod def splits(cls, fields, root='.data', train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs): """Create dataset objects for splits of the SQuAD dataset. Arguments: root: directory containing SQuAD data field: field for handling all columns train: The prefix of the train data. Default: 'train'. validation: The prefix of the validation data. Default: 'val'. Remaining keyword arguments: Passed to the splits method of Dataset. """ path = cls.download(root) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, train), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, validation), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, test), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class SRL(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['https://dada.cs.washington.edu/qasrl/data/wiki1.train.qa', 'https://dada.cs.washington.edu/qasrl/data/wiki1.dev.qa', 'https://dada.cs.washington.edu/qasrl/data/wiki1.test.qa'] name = 'srl' dirname = '' @classmethod def clean(cls, s): closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm", " 'd", " 're"]) opening_punctuation = set(['( ', '$ ']) both_sides = set([' - ']) s = ' '.join(s.split()).strip() s = s.replace('-LRB-', '(') s = s.replace('-RRB-', ')') s = s.replace('-LAB-', '<') s = s.replace('-RAB-', '>') s = s.replace('-AMP-', '&') s = s.replace('%pw', ' ') for p in closing_punctuation: s = s.replace(p, p.lstrip()) for p in opening_punctuation: s = s.replace(p, p.rstrip()) for p in both_sides: s = s.replace(p, p.strip()) s = s.replace('``', '') s = s.replace('`', '') s = s.replace("''", '') s = s.replace('“', '') s = s.replace('”', '') s = s.replace(" '", '') return ' '.join(s.split()).strip() def __init__(self, path, field, one_answer=True, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) examples, all_answers = [], [] skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples, all_answers = torch.load(cache_name) else: with open(os.path.expanduser(path)) as f: for line in f: ex = json.loads(line) t = ex['type'] aa = ex['all_answers'] context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) ex.squad_id = len(all_answers) all_answers.append(aa) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers), cache_name) FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False, lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token) fields.append(('squad_id', FIELD)) super(SRL, self).__init__(examples, fields, **kwargs) self.all_answers = all_answers @classmethod def cache_splits(cls, path, train='train', validation='dev', test='test'): splits = [train, validation, test] for split in splits: split_file_name = os.path.join(path, f'{split}.jsonl') if os.path.exists(split_file_name): continue wiki_file = os.path.join(path, f'wiki1.{split}.qa') with open(split_file_name, 'w') as split_file: with open(os.path.expanduser(wiki_file)) as f: def is_int(x): try: int(x) return True except: return False lines = [] for line in f.readlines(): line = ' '.join(line.split()).strip() if len(line) == 0: lines.append(line) continue if not 'WIKI1' in line.split('_')[0]: if not is_int(line.split()[0]) or len(line.split()) > 3: lines.append(line) new_example = True for line in lines: line = line.strip() if new_example: context = cls.clean(line) new_example = False continue if len(line) == 0: new_example = True continue question, answers = line.split('?') question = cls.clean(line.split('?')[0].replace(' _', '') +'?') answer = cls.clean(answers.split('###')[0]) all_answers = [cls.clean(x) for x in answers.split('###')] if answer not in context: low_answer = answer[0].lower() + answer[1:] up_answer = answer[0].upper() + answer[1:] if low_answer in context or up_answer in context: answer = low_answer if low_answer in context else up_answer else: if 'Darcy Burner' in answer: answer = 'Darcy Burner and other 2008 Democratic congressional candidates, in cooperation with some retired national security officials' elif 'E Street Band' in answer: answer = 'plan to work with the E Street Band again in the future' elif 'an electric sender' in answer: answer = 'an electronic sender' elif 'the US army' in answer: answer = 'the US Army' elif 'Rather than name the' in answer: answer = 'rather die than name the cause of his disease to his father' elif answer.lower() in context: answer = answer.lower() else: import pdb; pdb.set_trace() assert answer in context modified_all_answers = [] for a in all_answers: if a not in context: low_answer = a[0].lower() + a[1:] up_answer = a[0].upper() + a[1:] if low_answer in context or up_answer in context: a = low_answer if low_answer in context else up_answer else: if 'Darcy Burner' in a: a = 'Darcy Burner and other 2008 Democratic congressional candidates, in cooperation with some retired national security officials' elif 'E Street Band' in a: a = 'plan to work with the E Street Band again in the future' elif 'an electric sender' in a: a = 'an electronic sender' elif 'the US army' in a: a = 'the US Army' elif 'Rather than name the' in a: a = 'rather die than name the cause of his disease to his father' elif a.lower() in context: a = a.lower() else: import pdb; pdb.set_trace() assert a in context modified_all_answers.append(a) split_file.write(json.dumps({'context': context, 'question': question, 'answer': answer, 'type': 'wiki', 'all_answers': modified_all_answers})+'\n') @classmethod def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class WinogradSchema(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['https://s3.amazonaws.com/research.metamind.io/decaNLP/data/schema.txt'] name = 'schema' dirname = '' def __init__(self, path, field, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: examples = [] with open(os.path.expanduser(path)) as f: for line in f: ex = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super(WinogradSchema, self).__init__(examples, fields, **kwargs) @classmethod def cache_splits(cls, path): pattern = '\[.*\]' train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl')) if os.path.exists(train_jsonl): return def get_both_schema(context): variations = [x[1:-1].split('/') for x in re.findall(pattern, context)] splits = re.split(pattern, context) results = [] for which_schema in range(2): vs = [v[which_schema] for v in variations] context = '' for idx in range(len(splits)): context += splits[idx] if idx < len(vs): context += vs[idx] results.append(context) return results schemas = [] with open(os.path.expanduser(os.path.join(path, 'schema.txt'))) as schema_file: schema = [] for line in schema_file: if len(line.split()) == 0: schemas.append(schema) schema = [] continue else: schema.append(line.strip()) examples = [] for schema in schemas: context, question, answer = schema contexts = get_both_schema(context) questions = get_both_schema(question) answers = answer.split('/') for idx in range(2): answer = answers[idx] question = questions[idx] + f' {answers[0]} or {answers[1]}?' examples.append({'context': contexts[idx], 'question': question, 'answer': answer}) traindev = examples[:-100] test = examples[-100:] train = traindev[:80] dev = traindev[80:] splits = ['train', 'validation', 'test'] for split, examples in zip(splits, [train, dev, test]): with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file: for ex in examples: split_file.write(json.dumps(ex)+'\n') @classmethod def splits(cls, fields, root='.data', train='train', validation='validation', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}.jsonl'), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class WOZ(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_en.json', 'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_test_de.json', 'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_test_en.json', 'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_de.json', 'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_en.json', 'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_validate_de.json', 'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_validate_en.json'] name = 'woz' dirname = '' def __init__(self, path, field, subsample=None, description='woz.en', **kwargs): fields = [(x, field) for x in self.fields] FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False, lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token) fields.append(('woz_id', FIELD)) examples, all_answers = [], [] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples, all_answers = torch.load(cache_name) else: with open(os.path.expanduser(path)) as f: for woz_id, line in enumerate(f): ex = example_dict = json.loads(line) if example_dict['lang'] in description: context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) all_answers.append((ex['lang_dialogue_turn'], answer)) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, woz_id], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers), cache_name) super(WOZ, self).__init__(examples, fields, **kwargs) self.all_answers = all_answers @classmethod def cache_splits(cls, path, train='train', validation='validate', test='test'): train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl')) if os.path.exists(train_jsonl): return splits = [train, validation, test] file_name_base = 'woz_{}_{}.json' question_base = "What is the change in state" for split in [train, validation, test]: with open (os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file: for lang in ['en', 'de']: file_path = file_name_base.format(split, lang) with open(os.path.expanduser(os.path.join(path, file_path))) as src_file: dialogues = json.loads(src_file.read()) for di, d in enumerate(dialogues): previous_state = {'inform': [], 'request': []} turns = d['dialogue'] for ti, t in enumerate(turns): question = 'What is the change in state?' actions = [] for act in t['system_acts']: if isinstance(act, list): act = ': '.join(act) actions.append(act) actions = ', '.join(actions) if len(actions) > 0: actions += ' -- ' context = actions + t['transcript'] belief_state = t['belief_state'] delta_state = {'inform': [], 'request': []} current_state = {'inform': [], 'request': []} for item in belief_state: if 'slots' in item: slots = item['slots'] for slot in slots: act = item['act'] if act == 'inform': current_state['inform'].append(slot) if not slot in previous_state['inform']: delta_state['inform'].append(slot) else: prev_slot = previous_state['inform'][previous_state['inform'].index(slot)] if prev_slot[1] != slot[1]: delta_state['inform'].append(slot) else: delta_state['request'].append(slot[1]) current_state['request'].append(slot[1]) previous_state = current_state answer = '' if len(delta_state['inform']) > 0: answer = ', '.join([f'{x[0]}: {x[1]}' for x in delta_state['inform']]) answer += ';' if len(delta_state['request']) > 0: answer += ' ' answer += ', '.join(delta_state['request']) ex = {'context': ' '.join(context.split()), 'question': ' '.join(question.split()), 'lang': lang, 'answer': answer if len(answer) > 1 else 'None', 'lang_dialogue_turn': f'{lang}_{di}_{ti}'} split_file.write(json.dumps(ex)+'\n') @classmethod def splits(cls, fields, root='.data', train='train', validation='validate', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}.jsonl'), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class MultiNLI(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip'] name = 'multinli' dirname = 'multinli_1.0' def __init__(self, path, field, subsample=None, description='multinli.in.out', **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: examples = [] with open(os.path.expanduser(path)) as f: for line in f: ex = example_dict = json.loads(line) if example_dict['subtask'] in description: context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super(MultiNLI, self).__init__(examples, fields, **kwargs) @classmethod def cache_splits(cls, path, train='multinli_1.0_train', validation='mulinli_1.0_dev_{}', test='test'): train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl')) if os.path.exists(train_jsonl): return with open(os.path.expanduser(os.path.join(path, f'train.jsonl')), 'a') as split_file: with open(os.path.expanduser(os.path.join(path, f'multinli_1.0_train.jsonl'))) as src_file: for line in src_file: ex = json.loads(line) ex = {'context': f'Premise: "{ex["sentence1"]}"', 'question': f'Hypothesis: "{ex["sentence2"]}" -- entailment, neutral, or contradiction?', 'answer': ex['gold_label'], 'subtask': 'multinli'} split_file.write(json.dumps(ex)+'\n') with open(os.path.expanduser(os.path.join(path, f'validation.jsonl')), 'a') as split_file: for subtask in ['matched', 'mismatched']: with open(os.path.expanduser(os.path.join(path, 'multinli_1.0_dev_{}.jsonl'.format(subtask)))) as src_file: for line in src_file: ex = json.loads(line) ex = {'context': f'Premise: "{ex["sentence1"]}"', 'question': f'Hypothesis: "{ex["sentence2"]}" -- entailment, neutral, or contradiction?', 'answer': ex['gold_label'], 'subtask': 'in' if subtask == 'matched' else 'out'} split_file.write(json.dumps(ex)+'\n') @classmethod def splits(cls, fields, root='.data', train='train', validation='validation', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}.jsonl'), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class ZeroShotRE(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['http://nlp.cs.washington.edu/zeroshot/relation_splits.tar.bz2'] dirname = 'relation_splits' name = 'zre' def __init__(self, path, field, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: examples = [] with open(os.path.expanduser(path)) as f: for line in f: ex = example_dict = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super().__init__(examples, fields, **kwargs) @classmethod def cache_splits(cls, path, train='train', validation='dev', test='test'): train_jsonl = os.path.expanduser(os.path.join(path, f'{train}.jsonl')) if os.path.exists(train_jsonl): return base_file_name = '{}.0' for split in [train, validation, test]: src_file_name = base_file_name.format(split) with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file: with open(os.path.expanduser(os.path.join(path, src_file_name))) as src_file: for line in src_file: split_line = line.split('\t') if len(split_line) == 4: answer = '' relation, question, subject, context = split_line else: relation, question, subject, context = split_line[:4] answer = ', '.join(split_line[4:]) question = question.replace('XXX', subject) ex = {'context': context, 'question': question, 'answer': answer if len(answer) > 0 else 'unanswerable'} split_file.write(json.dumps(ex)+'\n') @classmethod def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}.jsonl'), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class OntoNotesNER(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['http://conll.cemantix.org/2012/download/ids/english/all/train.id', 'http://conll.cemantix.org/2012/download/ids/english/all/development.id', 'http://conll.cemantix.org/2012/download/ids/english/all/test.id'] name = 'ontonotes.ner' dirname = '' @classmethod def clean(cls, s): closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " '", " n't ", " %"]) opening_punctuation = set(['( ', '$ ']) both_sides = set([' - ']) s = ' '.join(s.split()).strip() s = s.replace(' /.', ' .') s = s.replace(' /?', ' ?') s = s.replace('-LRB-', '(') s = s.replace('-RRB-', ')') s = s.replace('-LAB-', '<') s = s.replace('-RAB-', '>') s = s.replace('-AMP-', '&') s = s.replace('%pw', ' ') for p in closing_punctuation: s = s.replace(p, p.lstrip()) for p in opening_punctuation: s = s.replace(p, p.rstrip()) for p in both_sides: s = s.replace(p, p.strip()) s = s.replace('``', '"') s = s.replace("''", '"') quote_is_open = True quote_idx = s.find('"') raw = '' while quote_idx >= 0: start_enamex_open_idx = s.find(' -1: end_enamex_open_idx = s.find('">') + 2 if start_enamex_open_idx <= quote_idx <= end_enamex_open_idx: raw += s[:end_enamex_open_idx] s = s[end_enamex_open_idx:] quote_idx = s.find('"') continue if quote_is_open: raw += s[:quote_idx+1] s = s[quote_idx+1:].strip() quote_is_open = False else: raw += s[:quote_idx].strip() + '"' s = s[quote_idx+1:] quote_is_open = True quote_idx = s.find('"') raw += s return ' '.join(raw.split()).strip() def __init__(self, path, field, one_answer=True, subsample=None, path_to_files='.data/ontonotes-release-5.0/data/files', subtask='all', nones=True, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), subtask, str(nones)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: examples = [] with open(os.path.expanduser(path)) as f: for line in f: example_dict = json.loads(line) t = example_dict['type'] a = example_dict['answer'] if (subtask == 'both' or t == subtask): if a != 'None' or nones: ex = example_dict context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super(OntoNotesNER, self).__init__(examples, fields, **kwargs) @classmethod def cache_splits(cls, path, path_to_files, train='train', validation='development', test='test'): label_to_answer = {'PERSON': 'person', 'NORP': 'political', 'FAC': 'facility', 'ORG': 'organization', 'GPE': 'geopolitical', 'LOC': 'location', 'PRODUCT': 'product', 'EVENT': 'event', 'WORK_OF_ART': 'artwork', 'LAW': 'legal', 'LANGUAGE': 'language', 'DATE': 'date', 'TIME': 'time', 'PERCENT': 'percentage', 'MONEY': 'monetary', 'QUANTITY': 'quantitative', 'ORDINAL': 'ordinal', 'CARDINAL': 'cardinal'} pluralize = {'person': 'persons', 'political': 'political', 'facility': 'facilities', 'organization': 'organizations', 'geopolitical': 'geopolitical', 'location': 'locations', 'product': 'products', 'event': 'events', 'artwork': 'artworks', 'legal': 'legal', 'language': 'languages', 'date': 'dates', 'time': 'times', 'percentage': 'percentages', 'monetary': 'monetary', 'quantitative': 'quantitative', 'ordinal': 'ordinal', 'cardinal': 'cardinal'} for split in [train, validation, test]: split_file_name = os.path.join(path, f'{split}.jsonl') if os.path.exists(split_file_name): continue id_file = os.path.join(path, f'{split}.id') num_file_ids = 0 examples = [] with open(split_file_name, 'w') as split_file: with open(os.path.expanduser(id_file)) as f: for file_id in f: example_file_name = os.path.join(os.path.expanduser(path_to_files), file_id.strip()) + '.name' if not os.path.exists(example_file_name) or 'annotations/tc/ch' in example_file_name: continue num_file_ids += 1 with open(example_file_name) as example_file: lines = [x.strip() for x in example_file.readlines() if 'DOC' not in x] for line in lines: original = line line = cls.clean(line) entities = [] while True: start_enamex_open_idx = line.find('') + 2 start_enamex_close_idx = line.find('') end_enamex_close_idx = start_enamex_close_idx + len('') enamex_open_tag = line[start_enamex_open_idx:end_enamex_open_idx] enamex_close_tag = line[start_enamex_close_idx:end_enamex_close_idx] before_entity = line[:start_enamex_open_idx] entity = line[end_enamex_open_idx:start_enamex_close_idx] after_entity = line[end_enamex_close_idx:] if 'S_OFF' in enamex_open_tag: s_off_start = enamex_open_tag.find('S_OFF="') s_off_end = enamex_open_tag.find('">') if 'E_OFF' not in enamex_open_tag else enamex_open_tag.find('" E_OFF') s_off = int(enamex_open_tag[s_off_start+len('S_OFF="'):s_off_end]) enamex_open_tag = enamex_open_tag[:s_off_start-2] + '">' before_entity += entity[:s_off] entity = entity[s_off:] if 'E_OFF' in enamex_open_tag: s_off_start = enamex_open_tag.find('E_OFF="') s_off_end = enamex_open_tag.find('">') s_off = int(enamex_open_tag[s_off_start+len('E_OFF="'):s_off_end]) enamex_open_tag = enamex_open_tag[:s_off_start-2] + '">' after_entity = entity[-s_off:] + after_entity entity = entity[:-s_off] label_start = enamex_open_tag.find('TYPE="') + len('TYPE="') label_end = enamex_open_tag.find('">') label = enamex_open_tag[label_start:label_end] assert label in label_to_answer offsets = (len(before_entity), len(before_entity) + len(entity)) entities.append({'entity': entity, 'char_offsets': offsets, 'label': label}) line = before_entity + entity + after_entity context = line.strip() is_no_good = False for entity_tuple in entities: entity = entity_tuple['entity'] start, end = entity_tuple['char_offsets'] if not context[start:end] == entity: is_no_good = True break if is_no_good: logger.warning('Throwing out example that looks poorly labeled: ', original.strip(), ' (', file_id.strip(), ')') continue question = 'What are the tags for all entities?' answer = '; '.join([f'{x["entity"]} -- {label_to_answer[x["label"]]}' for x in entities]) if len(answer) == 0: answer = 'None' split_file.write(json.dumps({'context': context, 'question': question, 'answer': answer, 'file_id': file_id.strip(), 'original': original.strip(), 'entity_list': entities, 'type': 'all'})+'\n') partial_question = 'Which entities are {}?' for lab, ans in label_to_answer.items(): question = partial_question.format(pluralize[ans]) entity_of_type_lab = [x['entity'] for x in entities if x['label'] == lab] answer = ', '.join(entity_of_type_lab) if len(answer) == 0: answer = 'None' split_file.write(json.dumps({'context': context, 'question': question, 'answer': answer, 'file_id': file_id.strip(), 'original': original.strip(), 'entity_list': entities, 'type': 'one', })+'\n') @classmethod def splits(cls, fields, root='.data', train='train', validation='development', test='test', **kwargs): path_to_files = os.path.join(root, 'ontonotes-release-5.0', 'data', 'files') assert os.path.exists(path_to_files) path = cls.download(root) cls.cache_splits(path, path_to_files) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class SNLI(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip'] dirname = 'snli_1.0' name = 'snli' def __init__(self, path, field, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: examples = [] with open(os.path.expanduser(path)) as f: for line in f: example_dict = json.loads(line) ex = example_dict context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super().__init__(examples, fields, **kwargs) @classmethod def cache_splits(cls, path, train='train', validation='dev', test='test'): train_jsonl = os.path.expanduser(os.path.join(path, f'{train}.jsonl')) if os.path.exists(train_jsonl): return base_file_name = 'snli_1.0_{}.jsonl' for split in [train, validation, test]: src_file_name = base_file_name.format(split) with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file: with open(os.path.expanduser(os.path.join(path, src_file_name))) as src_file: for line in src_file: ex = json.loads(line) ex = {'context': f'Premise: "{ex["sentence1"]}"', 'question': f'Hypothesis: "{ex["sentence2"]}" -- entailment, neutral, or contradiction?', 'answer': ex['gold_label']} split_file.write(json.dumps(ex)+'\n') @classmethod def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, f'{train}.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, f'{test}.jsonl'), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) class JSON(CQA, data.Dataset): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) def __init__(self, path, field, subsample=None, **kwargs): fields = [(x, field) for x in self.fields] cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) examples = [] skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') examples = torch.load(cache_name) else: with open(os.path.expanduser(path)) as f: lines = f.readlines() for line in lines: ex = json.loads(line) context, question, answer = ex['context'], ex['question'], ex['answer'] context_question = get_context_question(context, question) ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields) examples.append(ex) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) super(JSON, self).__init__(examples, fields, **kwargs) @classmethod def splits(cls, fields, name, root='.data', train='train', validation='val', test='test', **kwargs): path = os.path.join(root, name) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) train_data = None if train is None else cls( os.path.join(path, 'train.jsonl'), fields, **kwargs) validation_data = None if validation is None else cls( os.path.join(path, 'val.jsonl'), fields, **kwargs) test_data = None if test is None else cls( os.path.join(path, 'test.jsonl'), fields, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None)