Remove all usage of Field from tasks and datasets

This commit is contained in:
Giovanni Campagna 2019-12-21 07:36:28 -08:00
parent aa628629bf
commit ea83b75089
8 changed files with 220 additions and 243 deletions

View File

@@ -41,6 +41,8 @@ class Example(NamedTuple):
question : List[str]
answer : List[str]
vocab_fields = ['context', 'question', 'answer']
@staticmethod
def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
args = [[example_id]]

View File

@@ -43,11 +43,12 @@ from .metrics import compute_metrics
from .utils.embeddings import load_embeddings
from .tasks.registry import get_tasks
from . import models
from .text.data import Iterator, ReversibleField
from .data.example import Example
from .text.data import ReversibleField
logger = logging.getLogger(__name__)
def get_all_splits(args, new_vocab):
def get_all_splits(args, new_field):
splits = []
for task in args.tasks:
logger.info(f'Loading {task}')
@@ -62,8 +63,8 @@ def get_all_splits(args, new_vocab):
kwargs['skip_cache_bool'] = args.skip_cache_bool
kwargs['cached_path'] = args.cached
kwargs['subsample'] = args.subsample
s = task.get_splits(new_vocab, root=args.data, **kwargs)[0]
preprocess_examples(args, [task], [s], new_vocab, train=False)
s = task.get_splits(root=args.data, tokenize=new_field.tokenize, lower=args.lower, **kwargs)[0]
preprocess_examples(args, [task], [s], new_field, train=False)
splits.append(s)
return splits
@@ -71,7 +72,7 @@ def get_all_splits(args, new_vocab):
def prepare_data(args, FIELD):
new_vocab = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
splits = get_all_splits(args, new_vocab)
new_vocab.build_vocab(*splits)
new_vocab.build_vocab(Example.vocab_fields, *splits)
logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens from training')
args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab)
FIELD.append_vocab(new_vocab)

View File

@@ -46,7 +46,7 @@ class AlmondDataset(generic_dataset.CQA):
base_url = None
def __init__(self, path, field, tokenize, contextual=False, reverse_task=False, subsample=None, **kwargs):
def __init__(self, path, contextual=False, reverse_task=False, subsample=None, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
@@ -100,23 +100,21 @@ class AlmondDataset(generic_dataset.CQA):
examples.append(generic_dataset.Example.from_raw('almond/' + _id,
context, question, answer,
tokenize=str.split, lower=field.lower))
tokenize=str.split, lower=False))
if len(examples) >= max_examples:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super().__init__(examples, field, **kwargs)
super().__init__(examples, **kwargs)
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='eval',
test='test', contextual=False, **kwargs):
def splits(cls, root='.data', train='train', validation='eval', test='test', contextual=False, **kwargs):
"""Create dataset objects for splits of the ThingTalk dataset.
Arguments:
@@ -138,14 +136,14 @@ class AlmondDataset(generic_dataset.CQA):
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux' + '.tsv'), fields, contextual=contextual, **kwargs)
aux_data = cls(os.path.join(path, 'aux' + '.tsv'), contextual=contextual, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, train + '.tsv'), fields, contextual=contextual, **kwargs)
os.path.join(path, train + '.tsv'), contextual=contextual, **kwargs)
val_data = None if validation is None else cls(
os.path.join(path, validation + '.tsv'), fields, contextual=contextual, **kwargs)
os.path.join(path, validation + '.tsv'), contextual=contextual, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test + '.tsv'), fields, contextual=contextual, **kwargs)
os.path.join(path, test + '.tsv'), contextual=contextual, **kwargs)
return tuple(d for d in (train_data, val_data, test_data, aux_data)
if d is not None)
@@ -180,8 +178,7 @@ class Almond(BaseAlmondTask):
i.e. natural language to formal language (ThingTalk) mapping"""
def get_splits(self, field, root, **kwargs):
return AlmondDataset.splits(
fields=field, root=root, tokenize=self.tokenize, **kwargs)
return AlmondDataset.splits(root=root, lower=field.lower, **kwargs)
@register_task('contextual_almond')

View File

@@ -38,10 +38,11 @@ class Multi30K(BaseTask):
def metrics(self):
return ['bleu', 'em', 'nem', 'nf1']
def get_splits(self, field, root, **kwargs):
def get_splits(self, root, **kwargs):
src, trg = ['.' + x for x in self.name.split('.')[1:]]
return generic_dataset.Multi30k.splits(exts=(src, trg),
fields=field, root=root, **kwargs)
root=root,
**kwargs)
@register_task('iwslt')
@@ -50,10 +51,11 @@ class IWSLT(BaseTask):
def metrics(self):
return ['bleu', 'em', 'nem', 'nf1']
def get_splits(self, field, root, **kwargs):
def get_splits(self, root, **kwargs):
src, trg = ['.' + x for x in self.name.split('.')[1:]]
return generic_dataset.IWSLT.splits(exts=(src, trg),
fields=field, root=root, **kwargs)
root=root,
**kwargs)
@register_task('squad')
@@ -62,9 +64,10 @@ class SQuAD(BaseTask):
def metrics(self):
return ['nf1', 'em', 'nem']
def get_splits(self, field, root, **kwargs):
return generic_dataset.SQuAD.splits(
fields=field, root=root, description=self.name, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.SQuAD.splits(root=root,
description=self.name,
**kwargs)
@register_task('wikisql')
@@ -73,19 +76,22 @@ class WikiSQL(BaseTask):
def metrics(self):
return ['lfem', 'em', 'nem', 'nf1']
def get_splits(self, field, root, **kwargs):
def get_splits(self, root, **kwargs):
return generic_dataset.WikiSQL.splits(
fields=field, root=root, query_as_question='query_as_question' in self.name, **kwargs)
root=root,
query_as_question='query_as_question' in self.name,
**kwargs)
@register_task('ontonotes')
class OntoNotesNER(BaseTask):
def get_splits(self, field, root, **kwargs):
def get_splits(self, root, **kwargs):
split_task = self.name.split('.')
_, _, subtask, nones, counting = split_task
return generic_dataset.OntoNotesNER.splits(
subtask=subtask, nones=True if nones == 'nones' else False,
fields=field, root=root, **kwargs)
root=root,
**kwargs)
@register_task('woz')
@@ -94,16 +100,19 @@ class WoZ(BaseTask):
def metrics(self):
return ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue', 'em', 'nem', 'nf1']
def get_splits(self, field, root, **kwargs):
def get_splits(self, root, **kwargs):
return generic_dataset.WOZ.splits(description=self.name,
fields=field, root=root, **kwargs)
root=root,
**kwargs)
@register_task('multinli')
class MultiNLI(BaseTask):
def get_splits(self, field, root, **kwargs):
def get_splits(self, root, **kwargs):
return generic_dataset.MultiNLI.splits(description=self.name,
fields=field, root=root, **kwargs)
root=root,
**kwargs)
@register_task('srl')
class SRL(BaseTask):
@@ -111,20 +120,20 @@ class SRL(BaseTask):
def metrics(self):
return ['nf1', 'em', 'nem']
def get_splits(self, field, root, **kwargs):
return generic_dataset.SRL.splits(fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.SRL.splits(root=root, **kwargs)
@register_task('snli')
class SNLI(BaseTask):
def get_splits(self, field, root, **kwargs):
return generic_dataset.SNLI.splits(fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.SNLI.splits(root=root, **kwargs)
@register_task('schema')
class WinogradSchema(BaseTask):
def get_splits(self, field, root, **kwargs):
return generic_dataset.WinogradSchema.splits(fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.WinogradSchema.splits(root=root, **kwargs)
class BaseSummarizationTask(BaseTask):
@@ -142,23 +151,21 @@ class BaseSummarizationTask(BaseTask):
@register_task('cnn')
class CNN(BaseSummarizationTask):
def get_splits(self, field, root, **kwargs):
return generic_dataset.CNN.splits(fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.CNN.splits(root=root, **kwargs)
@register_task('dailymail')
class DailyMail(BaseSummarizationTask):
def get_splits(self, field, root, **kwargs):
return generic_dataset.DailyMail.splits(fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.DailyMail.splits(root=root, **kwargs)
@register_task('cnn_dailymail')
class CNNDailyMail(BaseSummarizationTask):
def get_splits(self, field, root, **kwargs):
split_cnn = generic_dataset.CNN.splits(
fields=field, root=root, **kwargs)
split_dm = generic_dataset.DailyMail.splits(
fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
split_cnn = generic_dataset.CNN.splits(root=root, **kwargs)
split_dm = generic_dataset.DailyMail.splits(root=root, **kwargs)
for scnn, sdm in zip(split_cnn, split_dm):
scnn.examples.extend(sdm)
return split_cnn
@@ -166,9 +173,8 @@ class CNNDailyMail(BaseSummarizationTask):
@register_task('sst')
class SST(BaseTask):
def get_splits(self, field, root, **kwargs):
return generic_dataset.SST.splits(
fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.SST.splits(root=root, **kwargs)
@register_task('imdb')
@@ -176,9 +182,9 @@ class IMDB(BaseTask):
def preprocess_example(self, ex, train=False, max_context_length=None):
return ex._replace(context=ex.context[:max_context_length])
def get_splits(self, field, root, **kwargs):
def get_splits(self, root, **kwargs):
kwargs['validation'] = None
return generic_dataset.IMDb.splits(fields=field, root=root, **kwargs)
return generic_dataset.IMDb.splits(root=root, **kwargs)
@register_task('zre')
@@ -187,5 +193,5 @@ class ZRE(BaseTask):
def metrics(self):
return ['corpus_f1', 'precision', 'recall', 'em', 'nem', 'nf1']
def get_splits(self, field, root, **kwargs):
return generic_dataset.ZeroShotRE.splits(fields=field, root=root, **kwargs)
def get_splits(self, root, **kwargs):
return generic_dataset.ZeroShotRE.splits(root=root, **kwargs)

View File

@@ -51,10 +51,6 @@ def make_example_id(dataset, example_id):
class CQA(data.Dataset):
def __init__(self, examples, field, **kwargs):
fields = [(x, field) for x in Example._fields]
super().__init__(examples, fields, **kwargs)
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
@@ -69,7 +65,7 @@ class IMDb(CQA):
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, subsample=None, **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
examples = []
labels = {'neg': 'negative', 'pos': 'positive'}
question = 'Is this review negative or positive?'
@@ -88,30 +84,29 @@ class IMDb(CQA):
answer = labels[label]
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) > subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super().__init__(examples, field, **kwargs)
super().__init__(examples, **kwargs)
@classmethod
def splits(cls, fields, root='.data',
train='train', validation=None, test='test', **kwargs):
def splits(cls, root='.data', train='train', validation=None, test='test', **kwargs):
assert validation is None
path = cls.download(root)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}'), fields, **kwargs)
os.path.join(path, f'{train}'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}'), fields, **kwargs)
os.path.join(path, f'{test}'), **kwargs)
return tuple(d for d in (train_data, test_data, aux_data)
if d is not None)
@@ -128,7 +123,7 @@ class SST(CQA):
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, subsample=None, **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
@@ -149,7 +144,7 @@ class SST(CQA):
answer = labels[int(parsed[0])]
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) > subsample:
break
@@ -159,25 +154,24 @@ class SST(CQA):
torch.save(examples, cache_name)
self.examples = examples
super().__init__(examples, field, **kwargs)
super().__init__(examples, **kwargs)
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='dev', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
postfix = f'_binary_sent.csv'
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, f'aux{postfix}'), fields, **kwargs)
aux_data = cls(os.path.join(path, f'aux{postfix}'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}{postfix}'), fields, **kwargs)
os.path.join(path, f'{train}{postfix}'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}{postfix}'), fields, **kwargs)
os.path.join(path, f'{validation}{postfix}'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}{postfix}'), fields, **kwargs)
os.path.join(path, f'{test}{postfix}'), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
@@ -188,7 +182,7 @@ class TranslationDataset(CQA):
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, exts, field, subsample=None, **kwargs):
def __init__(self, path, exts, subsample=None, tokenize=None, lower=False, **kwargs):
"""Create a TranslationDataset given paths and fields.
Arguments:
@@ -220,7 +214,7 @@ class TranslationDataset(CQA):
answer = trg_line
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
@@ -228,11 +222,10 @@ class TranslationDataset(CQA):
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super().__init__(examples, field, **kwargs)
super().__init__(examples, **kwargs)
@classmethod
def splits(cls, exts, fields, root='.data',
train='train', validation='val', test='test', **kwargs):
def splits(cls, exts, root='.data', train='train', validation='val', test='test', **kwargs):
"""Create dataset objects for splits of a TranslationDataset.
Arguments:
@@ -252,14 +245,14 @@ class TranslationDataset(CQA):
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux'), exts, fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux'), exts, **kwargs)
train_data = None if train is None else cls(
os.path.join(path, train), exts, fields, **kwargs)
os.path.join(path, train), exts, **kwargs)
val_data = None if validation is None else cls(
os.path.join(path, validation), exts, fields, **kwargs)
os.path.join(path, validation), exts, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), exts, fields, **kwargs)
os.path.join(path, test), exts, **kwargs)
return tuple(d for d in (train_data, val_data, test_data, aux_data)
if d is not None)
@@ -279,7 +272,7 @@ class IWSLT(TranslationDataset, CQA):
base_dirname = '{}-{}'
@classmethod
def splits(cls, exts, fields, root='.data',
def splits(cls, exts, root='.data',
train='train', validation='IWSLT16.TED.tst2013',
test='IWSLT16.TED.tst2014', **kwargs):
"""Create dataset objects for splits of the IWSLT dataset.
@@ -305,7 +298,7 @@ class IWSLT(TranslationDataset, CQA):
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux = '.'.join(['aux', cls.dirname])
aux_data = cls(os.path.join(path, aux), exts, fields, **kwargs)
aux_data = cls(os.path.join(path, aux), exts, **kwargs)
if train is not None:
train = '.'.join([train, cls.dirname])
@@ -318,11 +311,11 @@ class IWSLT(TranslationDataset, CQA):
cls.clean(path)
train_data = None if train is None else cls(
os.path.join(path, train), exts, fields, **kwargs)
os.path.join(path, train), exts, **kwargs)
val_data = None if validation is None else cls(
os.path.join(path, validation), exts, fields, **kwargs)
os.path.join(path, validation), exts, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), exts, fields, **kwargs)
os.path.join(path, test), exts, **kwargs)
return tuple(d for d in (train_data, val_data, test_data, aux_data)
if d is not None)
@@ -349,7 +342,7 @@ class IWSLT(TranslationDataset, CQA):
fd_txt.write(l.strip() + '\n')
class SQuAD(CQA, data.Dataset):
class SQuAD(CQA):
@staticmethod
def sort_key(ex):
@@ -362,7 +355,7 @@ class SQuAD(CQA, data.Dataset):
name = 'squad'
dirname = ''
def __init__(self, path, field, subsample=None, **kwargs):
def __init__(self, path, subsample=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
@@ -389,7 +382,7 @@ class SQuAD(CQA, data.Dataset):
context = ' '.join(context.split())
ex = Example.from_raw(make_example_id(self, qa['id']),
context, question, answer,
tokenize=str.split, lower=field.lower)
tokenize=str.split, lower=lower)
else:
answer = qa['answers'][0]['text']
all_answers.append([a['text'] for a in qa['answers']])
@@ -435,8 +428,7 @@ class SQuAD(CQA, data.Dataset):
indexed_answer = tagged_context[context_spans[0]:context_spans[-1]+1]
if len(indexed_answer) != len(tokenized_answer):
import pdb; pdb.set_trace()
if field.eos_token is not None:
context_spans += [len(tagged_context)]
context_spans += [len(tagged_context)]
for context_idx, answer_word in zip(context_spans, ex.answer):
if context_idx == len(tagged_context):
continue
@@ -445,7 +437,7 @@ class SQuAD(CQA, data.Dataset):
ex = Example.from_raw(make_example_id(self, qa['id']),
' '.join(tagged_context), question, ' '.join(tokenized_answer),
tokenize=str.split, lower=field.lower)
tokenize=str.split, lower=lower)
examples.append(ex)
if subsample is not None and len(examples) > subsample:
@@ -459,14 +451,13 @@ class SQuAD(CQA, data.Dataset):
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers, q_ids), cache_name)
super(SQuAD, self).__init__(examples, field, **kwargs)
super(SQuAD, self).__init__(examples, **kwargs)
self.all_answers = all_answers
self.q_ids = q_ids
@classmethod
def splits(cls, fields, root='.data', description='squad1.1',
train='train', validation='dev', test=None, **kwargs):
def splits(cls, root='.data', description='squad1.1', train='train', validation='dev', test=None, **kwargs):
"""Create dataset objects for splits of the SQuAD dataset.
Arguments:
root: directory containing SQuAD data
@@ -484,15 +475,15 @@ class SQuAD(CQA, data.Dataset):
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux = '-'.join(['aux', extension])
aux_data = cls(os.path.join(path, aux), fields, **kwargs)
aux_data = cls(os.path.join(path, aux), **kwargs)
train = '-'.join([train, extension]) if train is not None else None
validation = '-'.join([validation, extension]) if validation is not None else None
train_data = None if train is None else cls(
os.path.join(path, train), fields, **kwargs)
os.path.join(path, train), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, validation), fields, **kwargs)
os.path.join(path, validation), **kwargs)
return tuple(d for d in (train_data, validation_data, aux_data)
if d is not None)
@@ -510,13 +501,13 @@ def fix_missing_period(line):
return line + "."
class Summarization(CQA, data.Dataset):
class Summarization(CQA):
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
@@ -533,14 +524,14 @@ class Summarization(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(Summarization, self).__init__(examples, field, **kwargs)
super(Summarization, self).__init__(examples, **kwargs)
@classmethod
def cache_splits(cls, path):
@@ -593,22 +584,21 @@ class Summarization(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data',
train='training', validation='validation', test='test', **kwargs):
def splits(cls, root='.data', train='training', validation='validation', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, 'training.jsonl'), fields, **kwargs)
os.path.join(path, 'training.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, 'validation.jsonl'), fields, one_answer=False, **kwargs)
os.path.join(path, 'validation.jsonl'), one_answer=False, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, 'test.jsonl'), fields, one_answer=False, **kwargs)
os.path.join(path, 'test.jsonl'), one_answer=False, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
@@ -659,7 +649,7 @@ class Query:
return cls(sel_index=d['sel'], agg_index=d['agg'], columns=t, conditions=d['conds'])
class WikiSQL(CQA, data.Dataset):
class WikiSQL(CQA):
@staticmethod
def sort_key(ex):
@@ -669,7 +659,7 @@ class WikiSQL(CQA, data.Dataset):
name = 'wikisql'
dirname = 'data'
def __init__(self, path, field, query_as_question=False, subsample=None, **kwargs):
def __init__(self, path, query_as_question=False, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
@@ -705,7 +695,7 @@ class WikiSQL(CQA, data.Dataset):
context += f'-- {human_query}'
examples.append(Example.from_raw(make_example_id(self, idx),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table})
if subsample is not None and len(examples) > subsample:
break
@@ -714,13 +704,12 @@ class WikiSQL(CQA, data.Dataset):
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers), cache_name)
super(WikiSQL, self).__init__(examples, field, **kwargs)
super(WikiSQL, self).__init__(examples, **kwargs)
self.all_answers = all_answers
@classmethod
def splits(cls, fields, root='.data',
train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs):
def splits(cls, root='.data', train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs):
"""Create dataset objects for splits of the SQuAD dataset.
Arguments:
root: directory containing SQuAD data
@@ -735,19 +724,19 @@ class WikiSQL(CQA, data.Dataset):
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, train), fields, **kwargs)
os.path.join(path, train), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, validation), fields, **kwargs)
os.path.join(path, validation), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, test), fields, **kwargs)
os.path.join(path, test), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class SRL(CQA, data.Dataset):
class SRL(CQA):
@staticmethod
def sort_key(ex):
@@ -762,9 +751,10 @@ class SRL(CQA, data.Dataset):
@classmethod
def clean(cls, s):
closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm", " 'd", " 're"])
opening_punctuation = set(['( ', '$ '])
both_sides = set([' - '])
closing_punctuation = {' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm",
" 'd", " 're"}
opening_punctuation = {'( ', '$ '}
both_sides = {' - '}
s = ' '.join(s.split()).strip()
s = s.replace('-LRB-', '(')
s = s.replace('-RRB-', ')')
@@ -787,7 +777,7 @@ class SRL(CQA, data.Dataset):
s = s.replace(" '", '')
return ' '.join(s.split()).strip()
def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
@@ -805,7 +795,7 @@ class SRL(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(all_answers)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
all_answers.append(aa)
if subsample is not None and len(examples) >= subsample:
break
@@ -813,7 +803,7 @@ class SRL(CQA, data.Dataset):
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers), cache_name)
super(SRL, self).__init__(examples, field, **kwargs)
super(SRL, self).__init__(examples, **kwargs)
self.all_answers = all_answers
@@ -909,22 +899,21 @@ class SRL(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='dev', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
os.path.join(path, f'{train}.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs)
os.path.join(path, f'{validation}.jsonl'), one_answer=False, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs)
os.path.join(path, f'{test}.jsonl'), one_answer=False, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
@@ -940,7 +929,7 @@ class WinogradSchema(CQA, data.Dataset):
name = 'schema'
dirname = ''
def __init__(self, path, field, subsample=None, **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
@@ -955,14 +944,14 @@ class WinogradSchema(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(WinogradSchema, self).__init__(examples, field, **kwargs)
super(WinogradSchema, self).__init__(examples, **kwargs)
@classmethod
def cache_splits(cls, path):
@@ -1021,22 +1010,21 @@ class WinogradSchema(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='validation', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='validation', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
os.path.join(path, f'{train}.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
os.path.join(path, f'{validation}.jsonl'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
os.path.join(path, f'{test}.jsonl'), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
@@ -1058,10 +1046,11 @@ class WOZ(CQA, data.Dataset):
name = 'woz'
dirname = ''
def __init__(self, path, field, subsample=None, description='woz.en', **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, description='woz.en', **kwargs):
examples, all_answers = [], []
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description)
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
str(subsample), description)
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
@@ -1075,7 +1064,7 @@ class WOZ(CQA, data.Dataset):
all_answers.append((ex['lang_dialogue_turn'], answer))
examples.append(Example.from_raw(make_example_id(self, woz_id),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
@@ -1083,7 +1072,7 @@ class WOZ(CQA, data.Dataset):
logger.info(f'Caching data to {cache_name}')
torch.save((examples, all_answers), cache_name)
super(WOZ, self).__init__(examples, field, **kwargs)
super(WOZ, self).__init__(examples, **kwargs)
self.all_answers = all_answers
@classmethod
@@ -1150,26 +1139,26 @@ class WOZ(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data', train='train', validation='validate', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='validate', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
os.path.join(path, f'{train}.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
os.path.join(path, f'{validation}.jsonl'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
os.path.join(path, f'{test}.jsonl'), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class MultiNLI(CQA, data.Dataset):
class MultiNLI(CQA):
@staticmethod
def sort_key(ex):
@ -1180,9 +1169,10 @@ class MultiNLI(CQA, data.Dataset):
name = 'multinli'
dirname = 'multinli_1.0'
def __init__(self, path, field, subsample=None, description='multinli.in.out', **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, description='multinli.in.out', **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description)
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
str(subsample), description)
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
@ -1196,14 +1186,14 @@ class MultiNLI(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(MultiNLI, self).__init__(examples, field, **kwargs)
super(MultiNLI, self).__init__(examples, **kwargs)
@classmethod
def cache_splits(cls, path, train='multinli_1.0_train', validation='mulinli_1.0_dev_{}', test='test'):
@ -1234,26 +1224,26 @@ class MultiNLI(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data', train='train', validation='validation', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='validation', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
os.path.join(path, f'{train}.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
os.path.join(path, f'{validation}.jsonl'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
os.path.join(path, f'{test}.jsonl'), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class ZeroShotRE(CQA, data.Dataset):
class ZeroShotRE(CQA):
@staticmethod
def sort_key(ex):
@ -1264,9 +1254,10 @@ class ZeroShotRE(CQA, data.Dataset):
name = 'zre'
def __init__(self, path, field, subsample=None, **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
@ -1279,7 +1270,7 @@ class ZeroShotRE(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
@ -1287,7 +1278,7 @@ class ZeroShotRE(CQA, data.Dataset):
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super().__init__(examples, field, **kwargs)
super().__init__(examples, **kwargs)
@classmethod
def cache_splits(cls, path, train='train', validation='dev', test='test'):
@ -1316,21 +1307,21 @@ class ZeroShotRE(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
os.path.join(path, f'{train}.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
os.path.join(path, f'{validation}.jsonl'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
os.path.join(path, f'{test}.jsonl'), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
@ -1396,7 +1387,9 @@ class OntoNotesNER(CQA, data.Dataset):
return ' '.join(raw.split()).strip()
def __init__(self, path, field, one_answer=True, subsample=None, path_to_files='.data/ontonotes-release-5.0/data/files', subtask='all', nones=True, **kwargs):
def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False,
path_to_files='.data/ontonotes-release-5.0/data/files',
subtask='all', nones=True, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), subtask, str(nones))
skip_cache_bool = kwargs.pop('skip_cache_bool')
@ -1416,7 +1409,7 @@ class OntoNotesNER(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
@ -1424,7 +1417,7 @@ class OntoNotesNER(CQA, data.Dataset):
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(OntoNotesNER, self).__init__(examples, field, **kwargs)
super(OntoNotesNER, self).__init__(examples, **kwargs)
@classmethod
@ -1554,8 +1547,7 @@ class OntoNotesNER(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data',
train='train', validation='development', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='development', test='test', **kwargs):
path_to_files = os.path.join(root, 'ontonotes-release-5.0', 'data', 'files')
assert os.path.exists(path_to_files)
path = cls.download(root)
@ -1564,18 +1556,18 @@ class OntoNotesNER(CQA, data.Dataset):
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
os.path.join(path, f'{train}.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs)
os.path.join(path, f'{validation}.jsonl'), one_answer=False, **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs)
os.path.join(path, f'{test}.jsonl'), one_answer=False, **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class SNLI(CQA, data.Dataset):
class SNLI(CQA):
@staticmethod
def sort_key(ex):
@ -1586,9 +1578,10 @@ class SNLI(CQA, data.Dataset):
name = 'snli'
def __init__(self, path, field, subsample=None, **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
str(subsample))
skip_cache_bool = kwargs.pop('skip_cache_bool')
if os.path.exists(cache_name) and not skip_cache_bool:
logger.info(f'Loading cached data from {cache_name}')
@ -1602,7 +1595,7 @@ class SNLI(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
@ -1610,7 +1603,7 @@ class SNLI(CQA, data.Dataset):
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super().__init__(examples, field, **kwargs)
super().__init__(examples, **kwargs)
@classmethod
def cache_splits(cls, path, train='train', validation='dev', test='test'):
@ -1632,34 +1625,36 @@ class SNLI(CQA, data.Dataset):
@classmethod
def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
path = cls.download(root)
cls.cache_splits(path)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
os.path.join(path, f'{train}.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
os.path.join(path, f'{validation}.jsonl'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
os.path.join(path, f'{test}.jsonl'), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)
class JSON(CQA, data.Dataset):
class JSON(CQA):
name = 'json'
@staticmethod
def sort_key(ex):
return data.interleave_keys(len(ex.context), len(ex.answer))
def __init__(self, path, field, subsample=None, **kwargs):
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
cached_path = kwargs.pop('cached_path')
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
str(subsample))
examples = []
skip_cache_bool = kwargs.pop('skip_cache_bool')
@ -1674,30 +1669,29 @@ class JSON(CQA, data.Dataset):
context, question, answer = ex['context'], ex['question'], ex['answer']
examples.append(Example.from_raw(make_example_id(self, len(examples)),
context, question, answer,
tokenize=field.tokenize, lower=field.lower))
tokenize=tokenize, lower=lower))
if subsample is not None and len(examples) >= subsample:
break
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
logger.info(f'Caching data to {cache_name}')
torch.save(examples, cache_name)
super(JSON, self).__init__(examples, field, **kwargs)
super(JSON, self).__init__(examples, **kwargs)
@classmethod
def splits(cls, fields, name, root='.data',
train='train', validation='val', test='test', **kwargs):
def splits(cls, name, root='.data', train='train', validation='val', test='test', **kwargs):
path = os.path.join(root, name)
aux_data = None
if kwargs.get('curriculum', False):
kwargs.pop('curriculum')
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
train_data = None if train is None else cls(
os.path.join(path, 'train.jsonl'), fields, **kwargs)
os.path.join(path, 'train.jsonl'), **kwargs)
validation_data = None if validation is None else cls(
os.path.join(path, 'val.jsonl'), fields, **kwargs)
os.path.join(path, 'val.jsonl'), **kwargs)
test_data = None if test is None else cls(
os.path.join(path, 'test.jsonl'), fields, **kwargs)
os.path.join(path, 'test.jsonl'), **kwargs)
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
if d is not None)

View File

@ -18,19 +18,14 @@ class Dataset(torch.utils.data.Dataset):
fields: A dictionary containing the name of each column together with
its corresponding Field object. Two columns with the same Field
object will share a vocabulary.
fields (dict[str, Field]): Contains the name of each column or field, together
with the corresponding Field object. Two fields with the same Field object
will have a shared vocabulary.
"""
sort_key = None
def __init__(self, examples, fields, filter_pred=None):
def __init__(self, examples, filter_pred=None, **kwargs):
"""Create a dataset from a list of Examples and Fields.
Arguments:
examples: List of Examples.
fields (List(tuple(str, Field))): The Fields to use in this tuple. The
string is a field name, and the Field is the associated field.
filter_pred (callable or None): Use only examples for which
filter_pred(example) is True, or use all examples if None.
Default is None.
@ -41,16 +36,13 @@ class Dataset(torch.utils.data.Dataset):
if make_list:
examples = list(examples)
self.examples = examples
self.fields = dict(fields)
@classmethod
def splits(cls, path=None, root='.data', train=None, validation=None,
def splits(cls, root='.data', train=None, validation=None,
test=None, **kwargs):
"""Create Dataset objects for multiple splits of a dataset.
Arguments:
path (str): Common prefix of the splits' file paths, or None to use
the result of cls.download(root).
root (str): Root dataset storage directory. Default is '.data'.
train (str): Suffix to add to path for the train set, or None for no
train set. Default is None.
@ -65,8 +57,7 @@ class Dataset(torch.utils.data.Dataset):
split_datasets (tuple(Dataset)): Datasets for train, validation, and
test splits in that order, if provided.
"""
if path is None:
path = cls.download(root)
path = cls.download(root)
train_data = None if train is None else cls(
os.path.join(path, train), **kwargs)
val_data = None if validation is None else cls(
@ -89,11 +80,6 @@ class Dataset(torch.utils.data.Dataset):
for x in self.examples:
yield x
def __getattr__(self, attr):
if attr in self.fields:
for x in self.examples:
yield getattr(x, attr)
@classmethod
def download(cls, root, check=None):
"""Download and unzip an online archive (.zip, .gz, or .tgz).

View File

@ -244,7 +244,7 @@ class Field(RawField):
return (padded, lengths)
return padded
def build_vocab(self, *args, **kwargs):
def build_vocab(self, field_names, *args, **kwargs):
"""Construct the Vocab object for this field from one or more datasets.
Arguments:
@ -259,11 +259,7 @@ class Field(RawField):
counter = Counter()
sources = []
for arg in args:
if hasattr(arg, 'fields'):
sources += [getattr(arg, name) for name, field in
arg.fields.items() if field is self]
else:
sources.append(arg)
sources += [getattr(ex, name) for name in field_names for ex in arg]
for data in sources:
for x in data:
if not self.sequential:

View File

@ -34,15 +34,11 @@ import math
import time
import sys
from copy import deepcopy
import logging
from pprint import pformat
from logging import handlers
import numpy as np
from .utils.model_utils import init_model
import torch
from tensorboardX import SummaryWriter
from . import arguments
@ -53,6 +49,8 @@ from .utils.saver import Saver
from .utils.embeddings import load_embeddings
from .text.data import ReversibleField
from .data.numericalizer import DecoderVocabulary
from .data.example import Example
from .utils.model_utils import init_model
def initialize_logger(args, rank='main'):
@ -78,12 +76,9 @@ def log(rank='main'):
def prepare_data(args, field, logger):
if field is None:
logger.info(f'Constructing field')
FIELD = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
else:
FIELD = field
field = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
train_sets, val_sets, aux_sets, vocab_sets = [], [], [], []
for task in args.train_tasks:
@ -97,7 +92,7 @@ def prepare_data(args, field, logger):
kwargs['cached_path'] = args.cached
logger.info(f'Adding {task.name} to training datasets')
split = task.get_splits(FIELD, args.data, **kwargs)
split = task.get_splits(args.data, tokenize=field.tokenize, lower=args.lower, **kwargs)
if args.use_curriculum:
assert len(split) == 2
aux_sets.append(split[1])
@ -118,7 +113,7 @@ def prepare_data(args, field, logger):
kwargs['cached_path'] = args.cached
logger.info(f'Adding {task.name} to validation datasets')
split = task.get_splits(FIELD, args.data, **kwargs)
split = task.get_splits(args.data, tokenize=field.tokenize, lower=args.lower, **kwargs)
assert len(split) == 1
logger.info(f'{task.name} has {len(split[0])} validation examples')
val_sets.append(split[0])
@ -129,23 +124,23 @@ def prepare_data(args, field, logger):
vectors = load_embeddings(args, logger)
vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
logger.info(f'Building vocabulary')
FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
field.build_vocab(Example.vocab_fields, *vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
FIELD.decoder_vocab = DecoderVocabulary(FIELD.vocab.itos[:args.max_generative_vocab], FIELD.vocab)
field.decoder_vocab = DecoderVocabulary(field.vocab.itos[:args.max_generative_vocab], field.vocab)
logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens')
logger.info(f'Vocabulary has {len(field.vocab)} tokens')
logger.debug(f'The first 200 tokens:')
logger.debug(FIELD.vocab.itos[:200])
logger.debug(field.vocab.itos[:200])
if args.use_curriculum:
logger.info('Preprocessing auxiliary data for curriculum')
preprocess_examples(args, args.train_tasks, aux_sets, FIELD, logger, train=True)
preprocess_examples(args, args.train_tasks, aux_sets, field, logger, train=True)
logger.info('Preprocessing training data')
preprocess_examples(args, args.train_tasks, train_sets, FIELD, logger, train=True)
preprocess_examples(args, args.train_tasks, train_sets, field, logger, train=True)
logger.info('Preprocessing validation data')
preprocess_examples(args, args.val_tasks, val_sets, FIELD, logger, train=args.val_filter)
preprocess_examples(args, args.val_tasks, val_sets, field, logger, train=args.val_filter)
return FIELD, train_sets, val_sets, aux_sets
return field, train_sets, val_sets, aux_sets
def get_learning_rate(i, args):