# genienlp/decanlp/tasks/generic_dataset.py


#
# Copyright (c) 2018, Salesforce, Inc.
# The Board of Trustees of the Leland Stanford Junior University
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import os
import re
import revtok
import torch
import io
import csv
import json
import glob
import hashlib
import unicodedata
import logging
from ..text.torchtext.datasets import imdb
from ..text.torchtext.datasets import translation
from ..text.torchtext import data
CONTEXT_SPECIAL = 'Context:'
QUESTION_SPECIAL = 'Question:'
logger = logging.getLogger(__name__)
def get_context_question(context, question):
return CONTEXT_SPECIAL + ' ' + context + ' ' + QUESTION_SPECIAL + ' ' + question
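# For example, get_context_question('I loved this film.', 'Is this review negative or positive?')
# returns 'Context: I loved this film. Question: Is this review negative or positive?'.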
class CQA(data.Dataset):
fields = ['context', 'question', 'answer', 'context_special', 'question_special', 'context_question']
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
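# CQA is the common base class for every decaNLP task: each example carries the six text
# fields listed above, and sort_key interleaves context length with answer length so that
# batching groups examples of similar size together.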
class IMDb(CQA, imdb.IMDb):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
examples = []
labels = {'neg': 'negative', 'pos': 'positive'}
question = 'Is this review negative or positive?'
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
for label in ['pos', 'neg']:
for fname in glob.iglob(os.path.join(path, label, '*.txt')):
with open(fname, 'r') as f:
context = f.readline()
answer = labels[label]
context_question = get_context_question(context, question)
examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields))
if subsample is not None and len(examples) > subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(imdb.IMDb, self).__init__(examples, fields, **kwargs)
@classmethod
def splits(cls, fields, root='.data',
train='train', validation=None, test='test', **kwargs):
assert validation is None
path = cls.download(root)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}'), fields, **kwargs)
return tuple(d for d in (train_data, test_data, aux_data)
if d is not None)
class SST(CQA):
urls = ['https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/train_binary_sent.csv',
'https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/dev_binary_sent.csv',
'https://raw.githubusercontent.com/openai/generating-reviews-discovering-sentiment/master/data/test_binary_sent.csv']
name = 'sst'
dirname = ''
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
examples = []
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
labels = ['negative', 'positive']
question = 'Is this review ' + labels[0] + ' or ' + labels[1] + '?'
with io.open(os.path.expanduser(path), encoding='utf8') as f:
next(f)
for line in f:
parsed = list(csv.reader([line.rstrip('\n')]))[0]
context = parsed[-1]
answer = labels[int(parsed[0])]
context_question = get_context_question(context, question)
examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields))
if subsample is not None and len(examples) > subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
self.examples = examples
super().__init__(examples, fields, **kwargs)
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
postfix = '_binary_sent.csv'
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, f'aux{postfix}'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}{postfix}'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}{postfix}'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}{postfix}'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class TranslationDataset(translation.TranslationDataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, exts, field, subsample=None, tokenize=None, **kwargs):
"""Create a TranslationDataset given paths and fields.
Arguments:
path: Common prefix of paths to the data files for both languages.
exts: A tuple containing the extension to path for each language.
field: field for handling all columns
Remaining keyword arguments: Passed to the constructor of
data.Dataset.
"""
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
langs = {'.de': 'German', '.en': 'English', '.fr': 'French', '.ar': 'Arabic', '.cs': 'Czech', '.tt': 'ThingTalk', '.fa': 'Farsi'}
source, target = langs[exts[0]], langs[exts[1]]
src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts)
question = f'Translate from {source} to {target}'
examples = []
with open(src_path) as src_file, open(trg_path) as trg_file:
for src_line, trg_line in zip(src_file, trg_file):
src_line, trg_line = src_line.strip(), trg_line.strip()
if src_line != '' and trg_line != '':
context = src_line
answer = trg_line
context_question = get_context_question(context, question)
examples.append(data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields, tokenize=tokenize))
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(translation.TranslationDataset, self).__init__(examples, fields, **kwargs)
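# Illustrative note (not used by the code): for a hypothetical prefix such as
# '.data/iwslt/de-en/train' with exts=('.de', '.en'), the loader reads 'train.de' as the
# context and 'train.en' as the answer, and the question becomes
# 'Translate from German to English'.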
class Multi30k(TranslationDataset, CQA, translation.Multi30k):
pass
class IWSLT(TranslationDataset, CQA, translation.IWSLT):
pass
class SQuAD(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json',
'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json',]
name = 'squad'
dirname = ''
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
examples, all_answers, q_ids = [], [], []
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples, all_answers, q_ids = torch.load(cache_name)
else:
with open(os.path.expanduser(path)) as f:
squad = json.load(f)['data']
for document in squad:
title = document['title']
paragraphs = document['paragraphs']
for paragraph in paragraphs:
context = paragraph['context']
qas = paragraph['qas']
for qa in qas:
question = ' '.join(qa['question'].split())
q_ids.append(qa['id'])
squad_id = len(all_answers)
context_question = get_context_question(context, question)
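# SQuAD 2.0 questions may have no answer; such questions are mapped to the literal
# answer 'unanswerable', with sentinel span indices of -1.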
if len(qa['answers']) == 0:
answer = 'unanswerable'
all_answers.append(['unanswerable'])
context = ' '.join(context.split())
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
ex.context_spans = [-1, -1]
ex.answer_start = -1
ex.answer_end = -1
else:
answer = qa['answers'][0]['text']
all_answers.append([a['text'] for a in qa['answers']])
answer_start = qa['answers'][0]['answer_start']
answer_end = answer_start + len(answer)
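# To recover token-level answer spans, the raw answer is wrapped in sentinel strings
# (BEGIN/END) inside the context before tokenization; afterwards the sentinels are
# located, stripped out, and turned into token indices, and whitespace-only tokens are
# dropped while the indices are adjusted. The pdb.set_trace() calls below act as crude
# consistency checks that the recovered span still matches the tokenized answer.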
context_before_answer = context[:answer_start]
context_after_answer = context[answer_end:]
BEGIN = 'beginanswer '
END = ' endanswer'
tagged_context = context_before_answer + BEGIN + answer + END + context_after_answer
ex = data.Example.fromlist([tagged_context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
tokenized_answer = ex.answer
for xi, x in enumerate(ex.context):
if BEGIN in x:
answer_start = xi + 1
ex.context[xi] = x.replace(BEGIN, '')
if END in x:
answer_end = xi
ex.context[xi] = x.replace(END, '')
new_context = []
original_answer_start = answer_start
original_answer_end = answer_end
indexed_with_spaces = ex.context[answer_start:answer_end]
if len(indexed_with_spaces) != len(tokenized_answer):
import pdb; pdb.set_trace()
# remove spaces
for xi, x in enumerate(ex.context):
if len(x.strip()) == 0:
if xi <= original_answer_start:
answer_start -= 1
if xi < original_answer_end:
answer_end -= 1
else:
new_context.append(x)
ex.context = new_context
ex.answer = [x for x in ex.answer if len(x.strip()) > 0]
if len(ex.context[answer_start:answer_end]) != len(ex.answer):
import pdb; pdb.set_trace()
ex.context_spans = list(range(answer_start, answer_end))
indexed_answer = ex.context[ex.context_spans[0]:ex.context_spans[-1]+1]
if len(indexed_answer) != len(ex.answer):
import pdb; pdb.set_trace()
if field.eos_token is not None:
ex.context_spans += [len(ex.context)]
for context_idx, answer_word in zip(ex.context_spans, ex.answer):
if context_idx == len(ex.context):
continue
if ex.context[context_idx] != answer_word:
import pdb; pdb.set_trace()
ex.answer_start = ex.context_spans[0]
ex.answer_end = ex.context_spans[-1]
ex.squad_id = squad_id
examples.append(ex)
if subsample is not None and len(examples) > subsample:
break
if subsample is not None and len(examples) > subsample:
break
if subsample is not None and len(examples) > subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers, q_ids), cache_name)
FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
fields.append(('context_spans', FIELD))
fields.append(('answer_start', FIELD))
fields.append(('answer_end', FIELD))
fields.append(('squad_id', FIELD))
super(SQuAD, self).__init__(examples, fields, **kwargs)
self.all_answers = all_answers
self.q_ids = q_ids
@classmethod
def splits(cls, fields, root='.data', description='squad1.1',
train='train', validation='dev', test=None, **kwargs):
"""Create dataset objects for splits of the SQuAD dataset.
Arguments:
root: directory containing SQuAD data
fields: fields for handling all columns
train: The prefix of the train data. Default: 'train'.
validation: The prefix of the validation data. Default: 'dev'.
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
assert test is None
path = cls.download(root)
extension = 'v2.0.json' if '2.0' in description else 'v1.1.json'
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux = '-'.join(['aux', extension])
aux_data = cls(os.path.join(path, aux), fields, **kwargs)
train = '-'.join([train, extension]) if train is not None else None
validation = '-'.join([validation, extension]) if validation is not None else None
train_data = None if train is None else cls(
os.path.join(path, train), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, validation), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, aux_data)
if d is not None)
# https://github.com/abisee/cnn-dailymail/blob/8eace60f306dcbab30d1f1d715e379f07a3782db/make_datafiles.py
dm_single_close_quote = u'\u2019'
dm_double_close_quote = u'\u201d'
END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
def fix_missing_period(line):
"""Adds a period to a line that is missing a period"""
if "@highlight" in line: return line
if line=="": return line
if line[-1] in END_TOKENS: return line
return line + "."
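# e.g. fix_missing_period('The bill passed the senate') -> 'The bill passed the senate.',
# while '@highlight' markers, empty lines, and lines already ending in END_TOKENS are
# returned unchanged.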
class Summarization(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
examples = []
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
with open(os.path.expanduser(path)) as f:
lines = f.readlines()
for line in lines:
ex = json.loads(line)
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(Summarization, self).__init__(examples, fields, **kwargs)
@classmethod
def cache_splits(cls, path):
splits = ['training', 'validation', 'test']
for split in splits:
missing_stories, collected_stories = 0, 0
split_file_name = os.path.join(path, f'{split}.jsonl')
if os.path.exists(split_file_name):
continue
with open(split_file_name, 'w') as split_file:
url_file_name = os.path.join(path, f'{cls.name}_wayback_{split}_urls.txt')
with open(url_file_name) as url_file:
for url in url_file:
story_file_name = os.path.join(path, 'stories',
f"{hashlib.sha1(url.strip().encode('utf-8')).hexdigest()}.story")
try:
story_file = open(story_file_name)
except EnvironmentError as e:
missing_stories += 1
logger.warning(e)
if os.path.exists(split_file_name):
os.remove(split_file_name)
else:
with story_file:
article, highlight = [], []
is_highlight = False
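# CNN/DailyMail .story files list the article sentences first; each summary sentence is
# preceded by a line containing the '@highlight' marker, so everything after the first
# marker is collected as the highlight (answer) text.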
for line in story_file:
line = line.strip()
if line == "":
continue
line = fix_missing_period(line)
if line.startswith("@highlight"):
is_highlight = True
elif "@highlight" in line:
raise ValueError(f'unexpected @highlight in the middle of a line: {line}')
elif is_highlight:
highlight.append(line)
else:
article.append(line)
example = {'context': unicodedata.normalize('NFKC', ' '.join(article)),
'answer': unicodedata.normalize('NFKC', ' '.join(highlight)),
'question': 'What is the summary?'}
split_file.write(json.dumps(example)+'\n')
collected_stories += 1
if collected_stories % 1000 == 0:
logger.debug(example)
logger.warning(f'Missing {missing_stories} stories')
logger.info(f'Collected {collected_stories} stories')
@classmethod
def splits(cls, fields, root='.data',
train='training', validation='validation', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, 'training.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, 'validation.jsonl'), fields, one_answer=False, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, 'test.jsonl'), fields, one_answer=False, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class DailyMail(Summarization):
name = 'dailymail'
dirname = 'dailymail'
urls = [('https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfM1BxdkxVaTY2bWs', 'dailymail_stories.tgz'),
('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_training_urls.txt', 'dailymail/dailymail_wayback_training_urls.txt'),
('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_validation_urls.txt', 'dailymail/dailymail_wayback_validation_urls.txt'),
('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/dailymail_wayback_test_urls.txt', 'dailymail/dailymail_wayback_test_urls.txt')]
class CNN(Summarization):
name = 'cnn'
dirname = 'cnn'
urls = [('https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTHk4NFg2SndKcjQ', 'cnn_stories.tgz'),
('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_training_urls.txt', 'cnn/cnn_wayback_training_urls.txt'),
('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_validation_urls.txt', 'cnn/cnn_wayback_validation_urls.txt'),
('https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/cnn_wayback_test_urls.txt', 'cnn/cnn_wayback_test_urls.txt')]
class Query:
#https://github.com/salesforce/WikiSQL/blob/c2ed4f9b22db1cc2721805d53e6e76e07e2ccbdc/lib/query.py#L10
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
cond_ops = ['=', '>', '<', 'OP']
syms = ['SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION', 'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS']
def __init__(self, sel_index, agg_index, columns, conditions=tuple()):
self.sel_index = sel_index
self.agg_index = agg_index
self.columns = columns
self.conditions = list(conditions)
def __repr__(self):
rep = 'SELECT {agg} {sel} FROM table'.format(
agg=self.agg_ops[self.agg_index],
sel= self.columns[self.sel_index] if self.columns is not None else 'col{}'.format(self.sel_index),
)
if self.conditions:
rep += ' WHERE ' + ' AND '.join(['{} {} {}'.format(self.columns[i], self.cond_ops[o], v) for i, o, v in self.conditions])
return ' '.join(rep.split())
@classmethod
def from_dict(cls, d, t):
return cls(sel_index=d['sel'], agg_index=d['agg'], columns=t, conditions=d['conds'])
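# Illustrative sketch (not part of the library): how WikiSQL answers are produced from a
# raw entry. The header and sql dict below are hypothetical, but they follow the WikiSQL
# json format that WikiSQL.__init__ reads, where the answer is repr(Query.from_dict(sql, header)).
def _example_wikisql_answer():
    header = ['Player', 'Team', 'Points']
    sql = {'sel': 0, 'agg': 0, 'conds': [[2, 1, 30]]}
    # agg 0 is the empty aggregation, column 0 is 'Player', and the condition is Points > 30
    return repr(Query.from_dict(sql, header))  # 'SELECT Player FROM table WHERE Points > 30'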
class WikiSQL(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['https://github.com/salesforce/WikiSQL/raw/master/data.tar.bz2']
name = 'wikisql'
dirname = 'data'
def __init__(self, path, field, query_as_question=False, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
fields.append(('wikisql_id', FIELD))
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples, all_answers = torch.load(cache_name)
else:
expanded_path = os.path.expanduser(path)
table_path = os.path.splitext(expanded_path)
table_path = table_path[0] + '.tables' + table_path[1]
with open(table_path) as tables_file:
tables = [json.loads(line) for line in tables_file]
id_to_tables = {x['id']: x for x in tables}
all_answers = []
examples = []
with open(expanded_path) as example_file:
for idx, line in enumerate(example_file):
entry = json.loads(line)
human_query = entry['question']
table = id_to_tables[entry['table_id']]
sql = entry['sql']
header = table['header']
answer = repr(Query.from_dict(sql, header))
context = (f'The table has columns {", ".join(table["header"])} ' +
f'and key words {", ".join(Query.agg_ops[1:] + Query.cond_ops + Query.syms)}')
if query_as_question:
question = human_query
else:
question = 'What is the translation from English to SQL?'
context += f'-- {human_query}'
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, idx], fields)
examples.append(ex)
all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table})
if subsample is not None and len(examples) > subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers), cache_name)
super(WikiSQL, self).__init__(examples, fields, **kwargs)
self.all_answers = all_answers
@classmethod
def splits(cls, fields, root='.data',
train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs):
"""Create dataset objects for splits of the SQuAD dataset.
Arguments:
root: directory containing SQuAD data
field: field for handling all columns
train: The prefix of the train data. Default: 'train'.
validation: The prefix of the validation data. Default: 'val'.
Remaining keyword arguments: Passed to the splits method of
Dataset.
"""
path = cls.download(root)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, train), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, validation), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class SRL(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['https://dada.cs.washington.edu/qasrl/data/wiki1.train.qa',
'https://dada.cs.washington.edu/qasrl/data/wiki1.dev.qa',
'https://dada.cs.washington.edu/qasrl/data/wiki1.test.qa']
name = 'srl'
dirname = ''
@classmethod
def clean(cls, s):
closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm", " 'd", " 're"])
opening_punctuation = set(['( ', '$ '])
both_sides = set([' - '])
s = ' '.join(s.split()).strip()
s = s.replace('-LRB-', '(')
s = s.replace('-RRB-', ')')
s = s.replace('-LAB-', '<')
s = s.replace('-RAB-', '>')
s = s.replace('-AMP-', '&')
s = s.replace('%pw', ' ')
for p in closing_punctuation:
s = s.replace(p, p.lstrip())
for p in opening_punctuation:
s = s.replace(p, p.rstrip())
for p in both_sides:
s = s.replace(p, p.strip())
s = s.replace('``', '')
s = s.replace('`', '')
s = s.replace("''", '')
s = s.replace('“', '')
s = s.replace('”', '')
s = s.replace(" '", '')
return ' '.join(s.split()).strip()
def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
examples, all_answers = [], []
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples, all_answers = torch.load(cache_name)
else:
with open(os.path.expanduser(path)) as f:
for line in f:
ex = json.loads(line)
t = ex['type']
aa = ex['all_answers']
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
ex.squad_id = len(all_answers)
all_answers.append(aa)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers), cache_name)
FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
fields.append(('squad_id', FIELD))
super(SRL, self).__init__(examples, fields, **kwargs)
self.all_answers = all_answers
@classmethod
def cache_splits(cls, path, train='train', validation='dev', test='test'):
splits = [train, validation, test]
for split in splits:
split_file_name = os.path.join(path, f'{split}.jsonl')
if os.path.exists(split_file_name):
continue
wiki_file = os.path.join(path, f'wiki1.{split}.qa')
with open(split_file_name, 'w') as split_file:
with open(os.path.expanduser(wiki_file)) as f:
def is_int(x):
try:
int(x)
return True
except ValueError:
return False
lines = []
for line in f.readlines():
line = ' '.join(line.split()).strip()
if len(line) == 0:
lines.append(line)
continue
if not 'WIKI1' in line.split('_')[0]:
if not is_int(line.split()[0]) or len(line.split()) > 3:
lines.append(line)
new_example = True
for line in lines:
line = line.strip()
if new_example:
context = cls.clean(line)
new_example = False
continue
if len(line) == 0:
new_example = True
continue
question, answers = line.split('?')
question = cls.clean(line.split('?')[0].replace(' _', '') +'?')
answer = cls.clean(answers.split('###')[0])
all_answers = [cls.clean(x) for x in answers.split('###')]
if answer not in context:
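# The raw QA-SRL answers do not always appear verbatim in the cleaned context; try a
# simple case change first and fall back to a small set of hand-written corrections for
# known mismatches before asserting that the answer is a substring of the context.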
low_answer = answer[0].lower() + answer[1:]
up_answer = answer[0].upper() + answer[1:]
if low_answer in context or up_answer in context:
answer = low_answer if low_answer in context else up_answer
else:
if 'Darcy Burner' in answer:
answer = 'Darcy Burner and other 2008 Democratic congressional candidates, in cooperation with some retired national security officials'
elif 'E Street Band' in answer:
answer = 'plan to work with the E Street Band again in the future'
elif 'an electric sender' in answer:
answer = 'an electronic sender'
elif 'the US army' in answer:
answer = 'the US Army'
elif 'Rather than name the' in answer:
answer = 'rather die than name the cause of his disease to his father'
elif answer.lower() in context:
answer = answer.lower()
else:
import pdb; pdb.set_trace()
assert answer in context
modified_all_answers = []
for a in all_answers:
if a not in context:
low_answer = a[0].lower() + a[1:]
up_answer = a[0].upper() + a[1:]
if low_answer in context or up_answer in context:
a = low_answer if low_answer in context else up_answer
else:
if 'Darcy Burner' in a:
a = 'Darcy Burner and other 2008 Democratic congressional candidates, in cooperation with some retired national security officials'
elif 'E Street Band' in a:
a = 'plan to work with the E Street Band again in the future'
elif 'an electric sender' in a:
a = 'an electronic sender'
elif 'the US army' in a:
a = 'the US Army'
elif 'Rather than name the' in a:
a = 'rather die than name the cause of his disease to his father'
elif a.lower() in context:
a = a.lower()
else:
import pdb; pdb.set_trace()
assert a in context
modified_all_answers.append(a)
split_file.write(json.dumps({'context': context, 'question': question, 'answer': answer, 'type': 'wiki', 'all_answers': modified_all_answers})+'\n')
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class WinogradSchema(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['https://s3.amazonaws.com/research.metamind.io/decaNLP/data/schema.txt']
name = 'schema'
dirname = ''
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
examples = []
with open(os.path.expanduser(path)) as f:
for line in f:
ex = json.loads(line)
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(WinogradSchema, self).__init__(examples, fields, **kwargs)
@classmethod
def cache_splits(cls, path):
pattern = r'\[.*\]'
train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl'))
if os.path.exists(train_jsonl):
return
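# schema.txt stores each Winograd schema as a context/question/answer triple in which
# spans like '[option1/option2]' mark the part that differs between the two variants;
# get_both_schema expands such a string into both concrete versions.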
def get_both_schema(context):
variations = [x[1:-1].split('/') for x in re.findall(pattern, context)]
splits = re.split(pattern, context)
results = []
for which_schema in range(2):
vs = [v[which_schema] for v in variations]
context = ''
for idx in range(len(splits)):
context += splits[idx]
if idx < len(vs):
context += vs[idx]
results.append(context)
return results
schemas = []
with open(os.path.expanduser(os.path.join(path, 'schema.txt'))) as schema_file:
schema = []
for line in schema_file:
if len(line.split()) == 0:
schemas.append(schema)
schema = []
continue
else:
schema.append(line.strip())
examples = []
for schema in schemas:
context, question, answer = schema
contexts = get_both_schema(context)
questions = get_both_schema(question)
answers = answer.split('/')
for idx in range(2):
answer = answers[idx]
question = questions[idx] + f' {answers[0]} or {answers[1]}?'
examples.append({'context': contexts[idx], 'question': question, 'answer': answer})
traindev = examples[:-100]
test = examples[-100:]
train = traindev[:80]
dev = traindev[80:]
splits = ['train', 'validation', 'test']
for split, examples in zip(splits, [train, dev, test]):
with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file:
for ex in examples:
split_file.write(json.dumps(ex)+'\n')
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='validation', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class WOZ(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_en.json',
'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_test_de.json',
'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_test_en.json',
'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_de.json',
'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_train_en.json',
'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_validate_de.json',
'https://raw.githubusercontent.com/nmrksic/neural-belief-tracker/master/data/woz/woz_validate_en.json']
name = 'woz'
dirname = ''
def __init__(self, path, field, subsample=None, description='woz.en', **kwargs):
fields = [(x, field) for x in self.fields]
FIELD = data.Field(batch_first=True, use_vocab=False, sequential=False,
lower=False, numerical=True, eos_token=field.eos_token, init_token=field.init_token)
fields.append(('woz_id', FIELD))
examples, all_answers = [], []
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description)
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples, all_answers = torch.load(cache_name)
else:
with open(os.path.expanduser(path)) as f:
for woz_id, line in enumerate(f):
ex = example_dict = json.loads(line)
if example_dict['lang'] in description:
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
all_answers.append((ex['lang_dialogue_turn'], answer))
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question, woz_id], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers), cache_name)
super(WOZ, self).__init__(examples, fields, **kwargs)
self.all_answers = all_answers
@classmethod
def cache_splits(cls, path, train='train', validation='validate', test='test'):
train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl'))
if os.path.exists(train_jsonl):
return
splits = [train, validation, test]
file_name_base = 'woz_{}_{}.json'
question_base = "What is the change in state"
for split in [train, validation, test]:
with open (os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file:
for lang in ['en', 'de']:
file_path = file_name_base.format(split, lang)
with open(os.path.expanduser(os.path.join(path, file_path))) as src_file:
dialogues = json.loads(src_file.read())
for di, d in enumerate(dialogues):
previous_state = {'inform': [], 'request': []}
turns = d['dialogue']
for ti, t in enumerate(turns):
question = 'What is the change in state?'
actions = []
for act in t['system_acts']:
if isinstance(act, list):
act = ': '.join(act)
actions.append(act)
actions = ', '.join(actions)
if len(actions) > 0:
actions += ' -- '
context = actions + t['transcript']
belief_state = t['belief_state']
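# The target for each turn is the *change* in dialogue state: inform slots that are new
# or whose value changed since the previous turn, plus any newly requested slots.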
delta_state = {'inform': [], 'request': []}
current_state = {'inform': [], 'request': []}
for item in belief_state:
if 'slots' in item:
slots = item['slots']
for slot in slots:
act = item['act']
if act == 'inform':
current_state['inform'].append(slot)
if not slot in previous_state['inform']:
delta_state['inform'].append(slot)
else:
prev_slot = previous_state['inform'][previous_state['inform'].index(slot)]
if prev_slot[1] != slot[1]:
delta_state['inform'].append(slot)
else:
delta_state['request'].append(slot[1])
current_state['request'].append(slot[1])
previous_state = current_state
answer = ''
if len(delta_state['inform']) > 0:
answer = ', '.join([f'{x[0]}: {x[1]}' for x in delta_state['inform']])
answer += ';'
if len(delta_state['request']) > 0:
answer += ' '
answer += ', '.join(delta_state['request'])
ex = {'context': ' '.join(context.split()),
'question': ' '.join(question.split()), 'lang': lang,
'answer': answer if len(answer) > 1 else 'None',
'lang_dialogue_turn': f'{lang}_{di}_{ti}'}
split_file.write(json.dumps(ex)+'\n')
@classmethod
def splits(cls, fields, root='.data', train='train', validation='validate', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class MultiNLI(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['http://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip']
name = 'multinli'
dirname = 'multinli_1.0'
def __init__(self, path, field, subsample=None, description='multinli.in.out', **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description)
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
examples = []
with open(os.path.expanduser(path)) as f:
for line in f:
ex = example_dict = json.loads(line)
if example_dict['subtask'] in description:
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(MultiNLI, self).__init__(examples, fields, **kwargs)
@classmethod
def cache_splits(cls, path, train='multinli_1.0_train', validation='multinli_1.0_dev_{}', test='test'):
train_jsonl = os.path.expanduser(os.path.join(path, 'train.jsonl'))
if os.path.exists(train_jsonl):
return
with open(os.path.expanduser(os.path.join(path, f'train.jsonl')), 'a') as split_file:
with open(os.path.expanduser(os.path.join(path, f'multinli_1.0_train.jsonl'))) as src_file:
for line in src_file:
ex = json.loads(line)
ex = {'context': f'Premise: "{ex["sentence1"]}"',
'question': f'Hypothesis: "{ex["sentence2"]}" -- entailment, neutral, or contradiction?',
'answer': ex['gold_label'],
'subtask': 'multinli'}
split_file.write(json.dumps(ex)+'\n')
with open(os.path.expanduser(os.path.join(path, f'validation.jsonl')), 'a') as split_file:
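# The matched dev set is tagged 'in' (in-domain) and the mismatched set 'out', matching
# the 'multinli.in.out' description string that __init__ filters on.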
for subtask in ['matched', 'mismatched']:
with open(os.path.expanduser(os.path.join(path, 'multinli_1.0_dev_{}.jsonl'.format(subtask)))) as src_file:
for line in src_file:
ex = json.loads(line)
ex = {'context': f'Premise: "{ex["sentence1"]}"',
'question': f'Hypothesis: "{ex["sentence2"]}" -- entailment, neutral, or contradiction?',
'answer': ex['gold_label'],
'subtask': 'in' if subtask == 'matched' else 'out'}
split_file.write(json.dumps(ex)+'\n')
@classmethod
def splits(cls, fields, root='.data', train='train', validation='validation', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class ZeroShotRE(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['http://nlp.cs.washington.edu/zeroshot/relation_splits.tar.bz2']
dirname = 'relation_splits'
name = 'zre'
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
examples = []
with open(os.path.expanduser(path)) as f:
for line in f:
ex = example_dict = json.loads(line)
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super().__init__(examples, fields, **kwargs)
@classmethod
def cache_splits(cls, path, train='train', validation='dev', test='test'):
train_jsonl = os.path.expanduser(os.path.join(path, f'{train}.jsonl'))
if os.path.exists(train_jsonl):
return
base_file_name = '{}.0'
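# Each line of the zero-shot RE files is tab-separated: relation, question template,
# subject, context, followed by zero or more answers; 'XXX' in the template is replaced
# by the subject, and examples with no answer become 'unanswerable'.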
for split in [train, validation, test]:
src_file_name = base_file_name.format(split)
with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file:
with open(os.path.expanduser(os.path.join(path, src_file_name))) as src_file:
for line in src_file:
split_line = line.split('\t')
if len(split_line) == 4:
answer = ''
relation, question, subject, context = split_line
else:
relation, question, subject, context = split_line[:4]
answer = ', '.join(split_line[4:])
question = question.replace('XXX', subject)
ex = {'context': context,
'question': question,
'answer': answer if len(answer) > 0 else 'unanswerable'}
split_file.write(json.dumps(ex)+'\n')
@classmethod
def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class OntoNotesNER(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['http://conll.cemantix.org/2012/download/ids/english/all/train.id',
'http://conll.cemantix.org/2012/download/ids/english/all/development.id',
'http://conll.cemantix.org/2012/download/ids/english/all/test.id']
name = 'ontonotes.ner'
dirname = ''
@classmethod
def clean(cls, s):
closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " '", " n't ", " %"])
opening_punctuation = set(['( ', '$ '])
both_sides = set([' - '])
s = ' '.join(s.split()).strip()
s = s.replace(' /.', ' .')
s = s.replace(' /?', ' ?')
s = s.replace('-LRB-', '(')
s = s.replace('-RRB-', ')')
s = s.replace('-LAB-', '<')
s = s.replace('-RAB-', '>')
s = s.replace('-AMP-', '&')
s = s.replace('%pw', ' ')
for p in closing_punctuation:
s = s.replace(p, p.lstrip())
for p in opening_punctuation:
s = s.replace(p, p.rstrip())
for p in both_sides:
s = s.replace(p, p.strip())
s = s.replace('``', '"')
s = s.replace("''", '"')
quote_is_open = True
quote_idx = s.find('"')
raw = ''
while quote_idx >= 0:
start_enamex_open_idx = s.find('<ENAMEX')
if start_enamex_open_idx > -1:
end_enamex_open_idx = s.find('">') + 2
if start_enamex_open_idx <= quote_idx <= end_enamex_open_idx:
raw += s[:end_enamex_open_idx]
s = s[end_enamex_open_idx:]
quote_idx = s.find('"')
continue
if quote_is_open:
raw += s[:quote_idx+1]
s = s[quote_idx+1:].strip()
quote_is_open = False
else:
raw += s[:quote_idx].strip() + '"'
s = s[quote_idx+1:]
quote_is_open = True
quote_idx = s.find('"')
raw += s
return ' '.join(raw.split()).strip()
def __init__(self, path, field, one_answer=True, subsample=None, path_to_files='.data/ontonotes-release-5.0/data/files', subtask='all', nones=True, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), subtask, str(nones))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
examples = []
with open(os.path.expanduser(path)) as f:
for line in f:
example_dict = json.loads(line)
t = example_dict['type']
a = example_dict['answer']
if (subtask == 'both' or t == subtask):
if a != 'None' or nones:
ex = example_dict
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(OntoNotesNER, self).__init__(examples, fields, **kwargs)
@classmethod
def cache_splits(cls, path, path_to_files, train='train', validation='development', test='test'):
label_to_answer = {'PERSON': 'person',
'NORP': 'political',
'FAC': 'facility',
'ORG': 'organization',
'GPE': 'geopolitical',
'LOC': 'location',
'PRODUCT': 'product',
'EVENT': 'event',
'WORK_OF_ART': 'artwork',
'LAW': 'legal',
'LANGUAGE': 'language',
'DATE': 'date',
'TIME': 'time',
'PERCENT': 'percentage',
'MONEY': 'monetary',
'QUANTITY': 'quantitative',
'ORDINAL': 'ordinal',
'CARDINAL': 'cardinal'}
pluralize = {'person': 'persons', 'political': 'political', 'facility': 'facilities', 'organization': 'organizations',
'geopolitical': 'geopolitical', 'location': 'locations', 'product': 'products', 'event': 'events',
'artwork': 'artworks', 'legal': 'legal', 'language': 'languages', 'date': 'dates', 'time': 'times',
'percentage': 'percentages', 'monetary': 'monetary', 'quantitative': 'quantitative', 'ordinal': 'ordinal',
'cardinal': 'cardinal'}
for split in [train, validation, test]:
split_file_name = os.path.join(path, f'{split}.jsonl')
if os.path.exists(split_file_name):
continue
id_file = os.path.join(path, f'{split}.id')
num_file_ids = 0
examples = []
with open(split_file_name, 'w') as split_file:
with open(os.path.expanduser(id_file)) as f:
for file_id in f:
example_file_name = os.path.join(os.path.expanduser(path_to_files), file_id.strip()) + '.name'
if not os.path.exists(example_file_name) or 'annotations/tc/ch' in example_file_name:
continue
num_file_ids += 1
with open(example_file_name) as example_file:
lines = [x.strip() for x in example_file.readlines() if 'DOC' not in x]
for line in lines:
original = line
line = cls.clean(line)
entities = []
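# Each pass through this loop strips one <ENAMEX TYPE="..."> ... </ENAMEX> annotation
# from the line, recording the entity text, its character offsets in the de-tagged line,
# and its label; S_OFF/E_OFF attributes shift the entity boundaries by the given number
# of characters.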
while True:
start_enamex_open_idx = line.find('<ENAMEX')
if start_enamex_open_idx == -1:
break
end_enamex_open_idx = line.find('">') + 2
start_enamex_close_idx = line.find('</ENAMEX>')
end_enamex_close_idx = start_enamex_close_idx + len('</ENAMEX>')
enamex_open_tag = line[start_enamex_open_idx:end_enamex_open_idx]
enamex_close_tag = line[start_enamex_close_idx:end_enamex_close_idx]
before_entity = line[:start_enamex_open_idx]
entity = line[end_enamex_open_idx:start_enamex_close_idx]
after_entity = line[end_enamex_close_idx:]
if 'S_OFF' in enamex_open_tag:
s_off_start = enamex_open_tag.find('S_OFF="')
s_off_end = enamex_open_tag.find('">') if 'E_OFF' not in enamex_open_tag else enamex_open_tag.find('" E_OFF')
s_off = int(enamex_open_tag[s_off_start+len('S_OFF="'):s_off_end])
enamex_open_tag = enamex_open_tag[:s_off_start-2] + '">'
before_entity += entity[:s_off]
entity = entity[s_off:]
if 'E_OFF' in enamex_open_tag:
s_off_start = enamex_open_tag.find('E_OFF="')
s_off_end = enamex_open_tag.find('">')
s_off = int(enamex_open_tag[s_off_start+len('E_OFF="'):s_off_end])
enamex_open_tag = enamex_open_tag[:s_off_start-2] + '">'
after_entity = entity[-s_off:] + after_entity
entity = entity[:-s_off]
label_start = enamex_open_tag.find('TYPE="') + len('TYPE="')
label_end = enamex_open_tag.find('">')
label = enamex_open_tag[label_start:label_end]
assert label in label_to_answer
offsets = (len(before_entity), len(before_entity) + len(entity))
entities.append({'entity': entity, 'char_offsets': offsets, 'label': label})
line = before_entity + entity + after_entity
context = line.strip()
is_no_good = False
for entity_tuple in entities:
entity = entity_tuple['entity']
start, end = entity_tuple['char_offsets']
if not context[start:end] == entity:
is_no_good = True
break
if is_no_good:
logger.warning(f'Throwing out example that looks poorly labeled: {original.strip()} ({file_id.strip()})')
continue
question = 'What are the tags for all entities?'
answer = '; '.join([f'{x["entity"]} -- {label_to_answer[x["label"]]}' for x in entities])
if len(answer) == 0:
answer = 'None'
split_file.write(json.dumps({'context': context, 'question': question, 'answer': answer, 'file_id': file_id.strip(),
'original': original.strip(), 'entity_list': entities, 'type': 'all'})+'\n')
partial_question = 'Which entities are {}?'
for lab, ans in label_to_answer.items():
question = partial_question.format(pluralize[ans])
entity_of_type_lab = [x['entity'] for x in entities if x['label'] == lab]
answer = ', '.join(entity_of_type_lab)
if len(answer) == 0:
answer = 'None'
split_file.write(json.dumps({'context': context,
'question': question,
'answer': answer,
'file_id': file_id.strip(),
'original': original.strip(),
'entity_list': entities,
'type': 'one',
})+'\n')
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='development', test='test', **kwargs):
path_to_files = os.path.join(root, 'ontonotes-release-5.0', 'data', 'files')
assert os.path.exists(path_to_files)
path = cls.download(root)
cls.cache_splits(path, path_to_files)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class SNLI(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip']
dirname = 'snli_1.0'
name = 'snli'
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
examples = []
with open(os.path.expanduser(path)) as f:
for line in f:
example_dict = json.loads(line)
ex = example_dict
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super().__init__(examples, fields, **kwargs)
@classmethod
def cache_splits(cls, path, train='train', validation='dev', test='test'):
train_jsonl = os.path.expanduser(os.path.join(path, f'{train}.jsonl'))
if os.path.exists(train_jsonl):
return
base_file_name = 'snli_1.0_{}.jsonl'
for split in [train, validation, test]:
src_file_name = base_file_name.format(split)
with open(os.path.expanduser(os.path.join(path, f'{split}.jsonl')), 'a') as split_file:
with open(os.path.expanduser(os.path.join(path, src_file_name))) as src_file:
for line in src_file:
ex = json.loads(line)
ex = {'context': f'Premise: "{ex["sentence1"]}"',
'question': f'Hypothesis: "{ex["sentence2"]}" -- entailment, neutral, or contradiction?',
'answer': ex['gold_label']}
split_file.write(json.dumps(ex)+'\n')
@classmethod
def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class JSON(CQA, data.Dataset):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, subsample=None, **kwargs):
fields = [(x, field) for x in self.fields]
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
examples = []
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
examples = torch.load(cache_name)
else:
with open(os.path.expanduser(path)) as f:
lines = f.readlines()
for line in lines:
ex = json.loads(line)
context, question, answer = ex['context'], ex['question'], ex['answer']
context_question = get_context_question(context, question)
ex = data.Example.fromlist([context, question, answer, CONTEXT_SPECIAL, QUESTION_SPECIAL, context_question], fields)
examples.append(ex)
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(JSON, self).__init__(examples, fields, **kwargs)
@classmethod
def splits(cls, fields, name, root='.data',
train='train', validation='val', test='test', **kwargs):
path = os.path.join(root, name)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, 'train.jsonl'), fields, **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, 'val.jsonl'), fields, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, 'test.jsonl'), fields, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)