Remove all usage of Field from tasks and datasets

parent aa628629bf
commit ea83b75089
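A minimal sketch (not part of this diff) of the calling convention this commit moves to: datasets and tasks receive plain tokenize/lower arguments instead of a torchtext-style Field object. DemoDataset below is a hypothetical stand-in, not a class from this repository.

    from typing import Callable, List

    class DemoDataset:
        def __init__(self, rows: List[str],
                     tokenize: Callable[[str], List[str]] = str.split,
                     lower: bool = False):
            # Before: tokenize/lower were read off a shared FIELD object.
            # After: they are passed explicitly, so no Field is needed here.
            self.examples = [tokenize(row.lower() if lower else row) for row in rows]

    data = DemoDataset(['Show Me The Weather'], tokenize=str.split, lower=True)
    print(data.examples)  # [['show', 'me', 'the', 'weather']]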
@@ -41,6 +41,8 @@ class Example(NamedTuple):
    question : List[str]
    answer : List[str]

    vocab_fields = ['context', 'question', 'answer']

    @staticmethod
    def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False):
        args = [[example_id]]
@@ -43,11 +43,12 @@ from .metrics import compute_metrics
from .utils.embeddings import load_embeddings
from .tasks.registry import get_tasks
from . import models
from .text.data import Iterator, ReversibleField
from .data.example import Example
from .text.data import ReversibleField

logger = logging.getLogger(__name__)

def get_all_splits(args, new_vocab):
def get_all_splits(args, new_field):
    splits = []
    for task in args.tasks:
        logger.info(f'Loading {task}')

@@ -62,8 +63,8 @@ def get_all_splits(args, new_vocab):
        kwargs['skip_cache_bool'] = args.skip_cache_bool
        kwargs['cached_path'] = args.cached
        kwargs['subsample'] = args.subsample
        s = task.get_splits(new_vocab, root=args.data, **kwargs)[0]
        preprocess_examples(args, [task], [s], new_vocab, train=False)
        s = task.get_splits(root=args.data, tokenize=new_field.tokenize, lower=args.lower, **kwargs)[0]
        preprocess_examples(args, [task], [s], new_field, train=False)
        splits.append(s)
    return splits

@@ -71,7 +72,7 @@ def get_all_splits(args, new_vocab):
def prepare_data(args, FIELD):
    new_vocab = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
    splits = get_all_splits(args, new_vocab)
    new_vocab.build_vocab(*splits)
    new_vocab.build_vocab(Example.vocab_fields, *splits)
    logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens from training')
    args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab)
    FIELD.append_vocab(new_vocab)
@ -46,7 +46,7 @@ class AlmondDataset(generic_dataset.CQA):
|
|||
|
||||
base_url = None
|
||||
|
||||
def __init__(self, path, field, tokenize, contextual=False, reverse_task=False, subsample=None, **kwargs):
|
||||
def __init__(self, path, contextual=False, reverse_task=False, subsample=None, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
|
||||
|
@ -100,23 +100,21 @@ class AlmondDataset(generic_dataset.CQA):
|
|||
|
||||
examples.append(generic_dataset.Example.from_raw('almond/' + _id,
|
||||
context, question, answer,
|
||||
tokenize=str.split, lower=field.lower))
|
||||
tokenize=str.split, lower=False))
|
||||
if len(examples) >= max_examples:
|
||||
break
|
||||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super().__init__(examples, field, **kwargs)
|
||||
super().__init__(examples, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='train', validation='eval',
|
||||
test='test', contextual=False, **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='eval', test='test', contextual=False, **kwargs):
|
||||
|
||||
"""Create dataset objects for splits of the ThingTalk dataset.
|
||||
Arguments:
|
||||
|
@ -138,14 +136,14 @@ class AlmondDataset(generic_dataset.CQA):
|
|||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux' + '.tsv'), fields, contextual=contextual, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux' + '.tsv'), contextual=contextual, **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, train + '.tsv'), fields, contextual=contextual, **kwargs)
|
||||
os.path.join(path, train + '.tsv'), contextual=contextual, **kwargs)
|
||||
val_data = None if validation is None else cls(
|
||||
os.path.join(path, validation + '.tsv'), fields, contextual=contextual, **kwargs)
|
||||
os.path.join(path, validation + '.tsv'), contextual=contextual, **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, test + '.tsv'), fields, contextual=contextual, **kwargs)
|
||||
os.path.join(path, test + '.tsv'), contextual=contextual, **kwargs)
|
||||
return tuple(d for d in (train_data, val_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@@ -180,8 +178,7 @@ class Almond(BaseAlmondTask):
    i.e. natural language to formal language (ThingTalk) mapping"""

    def get_splits(self, field, root, **kwargs):
        return AlmondDataset.splits(
            fields=field, root=root, tokenize=self.tokenize, **kwargs)
        return AlmondDataset.splits(root=root, lower=field.lower, **kwargs)


@register_task('contextual_almond')
@ -38,10 +38,11 @@ class Multi30K(BaseTask):
|
|||
def metrics(self):
|
||||
return ['bleu', 'em', 'nem', 'nf1']
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
def get_splits(self, root, **kwargs):
|
||||
src, trg = ['.' + x for x in self.name.split('.')[1:]]
|
||||
return generic_dataset.Multi30k.splits(exts=(src, trg),
|
||||
fields=field, root=root, **kwargs)
|
||||
root=root,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_task('iwslt')
|
||||
|
@ -50,10 +51,11 @@ class IWSLT(BaseTask):
|
|||
def metrics(self):
|
||||
return ['bleu', 'em', 'nem', 'nf1']
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
def get_splits(self, root, **kwargs):
|
||||
src, trg = ['.' + x for x in self.name.split('.')[1:]]
|
||||
return generic_dataset.IWSLT.splits(exts=(src, trg),
|
||||
fields=field, root=root, **kwargs)
|
||||
root=root,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_task('squad')
|
||||
|
@ -62,9 +64,10 @@ class SQuAD(BaseTask):
|
|||
def metrics(self):
|
||||
return ['nf1', 'em', 'nem']
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.SQuAD.splits(
|
||||
fields=field, root=root, description=self.name, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.SQuAD.splits(root=root,
|
||||
description=self.name,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_task('wikisql')
|
||||
|
@ -73,19 +76,22 @@ class WikiSQL(BaseTask):
|
|||
def metrics(self):
|
||||
return ['lfem', 'em', 'nem', 'nf1']
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.WikiSQL.splits(
|
||||
fields=field, root=root, query_as_question='query_as_question' in self.name, **kwargs)
|
||||
root=root,
|
||||
query_as_question='query_as_question' in self.name,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_task('ontonotes')
|
||||
class OntoNotesNER(BaseTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
def get_splits(self, root, **kwargs):
|
||||
split_task = self.name.split('.')
|
||||
_, _, subtask, nones, counting = split_task
|
||||
return generic_dataset.OntoNotesNER.splits(
|
||||
subtask=subtask, nones=True if nones == 'nones' else False,
|
||||
fields=field, root=root, **kwargs)
|
||||
root=root,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_task('woz')
|
||||
|
@ -94,16 +100,19 @@ class WoZ(BaseTask):
|
|||
def metrics(self):
|
||||
return ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue', 'em', 'nem', 'nf1']
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.WOZ.splits(description=self.name,
|
||||
fields=field, root=root, **kwargs)
|
||||
root=root,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_task('multinli')
|
||||
class MultiNLI(BaseTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.MultiNLI.splits(description=self.name,
|
||||
fields=field, root=root, **kwargs)
|
||||
root=root,
|
||||
**kwargs)
|
||||
|
||||
|
||||
@register_task('srl')
|
||||
class SRL(BaseTask):
|
||||
|
@ -111,20 +120,20 @@ class SRL(BaseTask):
|
|||
def metrics(self):
|
||||
return ['nf1', 'em', 'nem']
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.SRL.splits(fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.SRL.splits(root=root, **kwargs)
|
||||
|
||||
|
||||
@register_task('snli')
|
||||
class SNLI(BaseTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.SNLI.splits(fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.SNLI.splits(root=root, **kwargs)
|
||||
|
||||
|
||||
@register_task('schema')
|
||||
class WinogradSchema(BaseTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.WinogradSchema.splits(fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.WinogradSchema.splits(root=root, **kwargs)
|
||||
|
||||
|
||||
class BaseSummarizationTask(BaseTask):
|
||||
|
@ -142,23 +151,21 @@ class BaseSummarizationTask(BaseTask):
|
|||
|
||||
@register_task('cnn')
|
||||
class CNN(BaseSummarizationTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.CNN.splits(fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.CNN.splits(root=root, **kwargs)
|
||||
|
||||
|
||||
@register_task('dailymail')
|
||||
class DailyMail(BaseSummarizationTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.DailyMail.splits(fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.DailyMail.splits(root=root, **kwargs)
|
||||
|
||||
|
||||
@register_task('cnn_dailymail')
|
||||
class CNNDailyMail(BaseSummarizationTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
split_cnn = generic_dataset.CNN.splits(
|
||||
fields=field, root=root, **kwargs)
|
||||
split_dm = generic_dataset.DailyMail.splits(
|
||||
fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
split_cnn = generic_dataset.CNN.splits(root=root, **kwargs)
|
||||
split_dm = generic_dataset.DailyMail.splits(root=root, **kwargs)
|
||||
for scnn, sdm in zip(split_cnn, split_dm):
|
||||
scnn.examples.extend(sdm)
|
||||
return split_cnn
|
||||
|
@ -166,9 +173,8 @@ class CNNDailyMail(BaseSummarizationTask):
|
|||
|
||||
@register_task('sst')
|
||||
class SST(BaseTask):
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.SST.splits(
|
||||
fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.SST.splits(root=root, **kwargs)
|
||||
|
||||
|
||||
@register_task('imdb')
|
||||
|
@ -176,9 +182,9 @@ class IMDB(BaseTask):
|
|||
def preprocess_example(self, ex, train=False, max_context_length=None):
|
||||
return ex._replace(context=ex.context[:max_context_length])
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
def get_splits(self, root, **kwargs):
|
||||
kwargs['validation'] = None
|
||||
return generic_dataset.IMDb.splits(fields=field, root=root, **kwargs)
|
||||
return generic_dataset.IMDb.splits(root=root, **kwargs)
|
||||
|
||||
|
||||
@register_task('zre')
|
||||
|
@ -187,5 +193,5 @@ class ZRE(BaseTask):
|
|||
def metrics(self):
|
||||
return ['corpus_f1', 'precision', 'recall', 'em', 'nem', 'nf1']
|
||||
|
||||
def get_splits(self, field, root, **kwargs):
|
||||
return generic_dataset.ZeroShotRE.splits(fields=field, root=root, **kwargs)
|
||||
def get_splits(self, root, **kwargs):
|
||||
return generic_dataset.ZeroShotRE.splits(root=root, **kwargs)
|
||||
|
|
|
@ -51,10 +51,6 @@ def make_example_id(dataset, example_id):
|
|||
|
||||
|
||||
class CQA(data.Dataset):
|
||||
def __init__(self, examples, field, **kwargs):
|
||||
fields = [(x, field) for x in Example._fields]
|
||||
super().__init__(examples, fields, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
@ -69,7 +65,7 @@ class IMDb(CQA):
|
|||
def sort_key(ex):
|
||||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
||||
def __init__(self, path, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
examples = []
|
||||
labels = {'neg': 'negative', 'pos': 'positive'}
|
||||
question = 'Is this review negative or positive?'
|
||||
|
@ -88,30 +84,29 @@ class IMDb(CQA):
|
|||
answer = labels[label]
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
if subsample is not None and len(examples) > subsample:
|
||||
break
|
||||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
super().__init__(examples, field, **kwargs)
|
||||
super().__init__(examples, **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='train', validation=None, test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation=None, test='test', **kwargs):
|
||||
assert validation is None
|
||||
path = cls.download(root)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}'), fields, **kwargs)
|
||||
os.path.join(path, f'{test}'), **kwargs)
|
||||
return tuple(d for d in (train_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -128,7 +123,7 @@ class SST(CQA):
|
|||
def sort_key(ex):
|
||||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
||||
def __init__(self, path, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
|
||||
|
@ -149,7 +144,7 @@ class SST(CQA):
|
|||
answer = labels[int(parsed[0])]
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
|
||||
if subsample is not None and len(examples) > subsample:
|
||||
break
|
||||
|
@ -159,25 +154,24 @@ class SST(CQA):
|
|||
torch.save(examples, cache_name)
|
||||
|
||||
self.examples = examples
|
||||
super().__init__(examples, field, **kwargs)
|
||||
super().__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='train', validation='dev', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
postfix = f'_binary_sent.csv'
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, f'aux{postfix}'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, f'aux{postfix}'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}{postfix}'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}{postfix}'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}{postfix}'), fields, **kwargs)
|
||||
os.path.join(path, f'{validation}{postfix}'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}{postfix}'), fields, **kwargs)
|
||||
os.path.join(path, f'{test}{postfix}'), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -188,7 +182,7 @@ class TranslationDataset(CQA):
|
|||
def sort_key(ex):
|
||||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
||||
def __init__(self, path, exts, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, exts, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
"""Create a TranslationDataset given paths and fields.
|
||||
|
||||
Arguments:
|
||||
|
@ -220,7 +214,7 @@ class TranslationDataset(CQA):
|
|||
answer = trg_line
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
||||
|
@ -228,11 +222,10 @@ class TranslationDataset(CQA):
|
|||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
super().__init__(examples, field, **kwargs)
|
||||
super().__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def splits(cls, exts, fields, root='.data',
|
||||
train='train', validation='val', test='test', **kwargs):
|
||||
def splits(cls, exts, root='.data', train='train', validation='val', test='test', **kwargs):
|
||||
"""Create dataset objects for splits of a TranslationDataset.
|
||||
|
||||
Arguments:
|
||||
|
@ -252,14 +245,14 @@ class TranslationDataset(CQA):
|
|||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux'), exts, fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux'), exts, **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, train), exts, fields, **kwargs)
|
||||
os.path.join(path, train), exts, **kwargs)
|
||||
val_data = None if validation is None else cls(
|
||||
os.path.join(path, validation), exts, fields, **kwargs)
|
||||
os.path.join(path, validation), exts, **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, test), exts, fields, **kwargs)
|
||||
os.path.join(path, test), exts, **kwargs)
|
||||
return tuple(d for d in (train_data, val_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -279,7 +272,7 @@ class IWSLT(TranslationDataset, CQA):
|
|||
base_dirname = '{}-{}'
|
||||
|
||||
@classmethod
|
||||
def splits(cls, exts, fields, root='.data',
|
||||
def splits(cls, exts, root='.data',
|
||||
train='train', validation='IWSLT16.TED.tst2013',
|
||||
test='IWSLT16.TED.tst2014', **kwargs):
|
||||
"""Create dataset objects for splits of the IWSLT dataset.
|
||||
|
@ -305,7 +298,7 @@ class IWSLT(TranslationDataset, CQA):
|
|||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux = '.'.join(['aux', cls.dirname])
|
||||
aux_data = cls(os.path.join(path, aux), exts, fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, aux), exts, **kwargs)
|
||||
|
||||
if train is not None:
|
||||
train = '.'.join([train, cls.dirname])
|
||||
|
@ -318,11 +311,11 @@ class IWSLT(TranslationDataset, CQA):
|
|||
cls.clean(path)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, train), exts, fields, **kwargs)
|
||||
os.path.join(path, train), exts, **kwargs)
|
||||
val_data = None if validation is None else cls(
|
||||
os.path.join(path, validation), exts, fields, **kwargs)
|
||||
os.path.join(path, validation), exts, **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, test), exts, fields, **kwargs)
|
||||
os.path.join(path, test), exts, **kwargs)
|
||||
return tuple(d for d in (train_data, val_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -349,7 +342,7 @@ class IWSLT(TranslationDataset, CQA):
|
|||
fd_txt.write(l.strip() + '\n')
|
||||
|
||||
|
||||
class SQuAD(CQA, data.Dataset):
|
||||
class SQuAD(CQA):
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
|
@ -362,7 +355,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
name = 'squad'
|
||||
dirname = ''
|
||||
|
||||
def __init__(self, path, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, subsample=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
|
||||
|
@ -389,7 +382,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
context = ' '.join(context.split())
|
||||
ex = Example.from_raw(make_example_id(self, qa['id']),
|
||||
context, question, answer,
|
||||
tokenize=str.split, lower=field.lower)
|
||||
tokenize=str.split, lower=lower)
|
||||
else:
|
||||
answer = qa['answers'][0]['text']
|
||||
all_answers.append([a['text'] for a in qa['answers']])
|
||||
|
@ -435,8 +428,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
indexed_answer = tagged_context[context_spans[0]:context_spans[-1]+1]
|
||||
if len(indexed_answer) != len(tokenized_answer):
|
||||
import pdb; pdb.set_trace()
|
||||
if field.eos_token is not None:
|
||||
context_spans += [len(tagged_context)]
|
||||
context_spans += [len(tagged_context)]
|
||||
for context_idx, answer_word in zip(context_spans, ex.answer):
|
||||
if context_idx == len(tagged_context):
|
||||
continue
|
||||
|
@ -445,7 +437,7 @@ class SQuAD(CQA, data.Dataset):
|
|||
|
||||
ex = Example.from_raw(make_example_id(self, qa['id']),
|
||||
' '.join(tagged_context), question, ' '.join(tokenized_answer),
|
||||
tokenize=str.split, lower=field.lower)
|
||||
tokenize=str.split, lower=lower)
|
||||
|
||||
examples.append(ex)
|
||||
if subsample is not None and len(examples) > subsample:
|
||||
|
@ -459,14 +451,13 @@ class SQuAD(CQA, data.Dataset):
|
|||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save((examples, all_answers, q_ids), cache_name)
|
||||
|
||||
super(SQuAD, self).__init__(examples, field, **kwargs)
|
||||
super(SQuAD, self).__init__(examples, **kwargs)
|
||||
self.all_answers = all_answers
|
||||
self.q_ids = q_ids
|
||||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data', description='squad1.1',
|
||||
train='train', validation='dev', test=None, **kwargs):
|
||||
def splits(cls, root='.data', description='squad1.1', train='train', validation='dev', test=None, **kwargs):
|
||||
"""Create dataset objects for splits of the SQuAD dataset.
|
||||
Arguments:
|
||||
root: directory containing SQuAD data
|
||||
|
@ -484,15 +475,15 @@ class SQuAD(CQA, data.Dataset):
|
|||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux = '-'.join(['aux', extension])
|
||||
aux_data = cls(os.path.join(path, aux), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, aux), **kwargs)
|
||||
|
||||
train = '-'.join([train, extension]) if train is not None else None
|
||||
validation = '-'.join([validation, extension]) if validation is not None else None
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, train), fields, **kwargs)
|
||||
os.path.join(path, train), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, validation), fields, **kwargs)
|
||||
os.path.join(path, validation), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -510,13 +501,13 @@ def fix_missing_period(line):
|
|||
return line + "."
|
||||
|
||||
|
||||
class Summarization(CQA, data.Dataset):
|
||||
class Summarization(CQA):
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
||||
def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
|
||||
def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
|
||||
|
@ -533,14 +524,14 @@ class Summarization(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super(Summarization, self).__init__(examples, field, **kwargs)
|
||||
super(Summarization, self).__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def cache_splits(cls, path):
|
||||
|
@ -593,22 +584,21 @@ class Summarization(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='training', validation='validation', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='training', validation='validation', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
cls.cache_splits(path)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, 'training.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, 'training.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, 'validation.jsonl'), fields, one_answer=False, **kwargs)
|
||||
os.path.join(path, 'validation.jsonl'), one_answer=False, **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, 'test.jsonl'), fields, one_answer=False, **kwargs)
|
||||
os.path.join(path, 'test.jsonl'), one_answer=False, **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -659,7 +649,7 @@ class Query:
|
|||
return cls(sel_index=d['sel'], agg_index=d['agg'], columns=t, conditions=d['conds'])
|
||||
|
||||
|
||||
class WikiSQL(CQA, data.Dataset):
|
||||
class WikiSQL(CQA):
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
|
@ -669,7 +659,7 @@ class WikiSQL(CQA, data.Dataset):
|
|||
name = 'wikisql'
|
||||
dirname = 'data'
|
||||
|
||||
def __init__(self, path, field, query_as_question=False, subsample=None, **kwargs):
|
||||
def __init__(self, path, query_as_question=False, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample))
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
|
@ -705,7 +695,7 @@ class WikiSQL(CQA, data.Dataset):
|
|||
context += f'-- {human_query}'
|
||||
examples.append(Example.from_raw(make_example_id(self, idx),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table})
|
||||
if subsample is not None and len(examples) > subsample:
|
||||
break
|
||||
|
@ -714,13 +704,12 @@ class WikiSQL(CQA, data.Dataset):
|
|||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save((examples, all_answers), cache_name)
|
||||
|
||||
super(WikiSQL, self).__init__(examples, field, **kwargs)
|
||||
super(WikiSQL, self).__init__(examples, **kwargs)
|
||||
self.all_answers = all_answers
|
||||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs):
|
||||
def splits(cls, root='.data', train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs):
|
||||
"""Create dataset objects for splits of the SQuAD dataset.
|
||||
Arguments:
|
||||
root: directory containing SQuAD data
|
||||
|
@ -735,19 +724,19 @@ class WikiSQL(CQA, data.Dataset):
|
|||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, train), fields, **kwargs)
|
||||
os.path.join(path, train), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, validation), fields, **kwargs)
|
||||
os.path.join(path, validation), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, test), fields, **kwargs)
|
||||
os.path.join(path, test), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
||||
class SRL(CQA, data.Dataset):
|
||||
class SRL(CQA):
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
|
@ -762,9 +751,10 @@ class SRL(CQA, data.Dataset):
|
|||
|
||||
@classmethod
|
||||
def clean(cls, s):
|
||||
closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm", " 'd", " 're"])
|
||||
opening_punctuation = set(['( ', '$ '])
|
||||
both_sides = set([' - '])
|
||||
closing_punctuation = {' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm",
|
||||
" 'd", " 're"}
|
||||
opening_punctuation = {'( ', '$ '}
|
||||
both_sides = {' - '}
|
||||
s = ' '.join(s.split()).strip()
|
||||
s = s.replace('-LRB-', '(')
|
||||
s = s.replace('-RRB-', ')')
|
||||
|
@ -787,7 +777,7 @@ class SRL(CQA, data.Dataset):
|
|||
s = s.replace(" '", '')
|
||||
return ' '.join(s.split()).strip()
|
||||
|
||||
def __init__(self, path, field, one_answer=True, subsample=None, **kwargs):
|
||||
def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
|
||||
|
@ -805,7 +795,7 @@ class SRL(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(all_answers)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
all_answers.append(aa)
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -813,7 +803,7 @@ class SRL(CQA, data.Dataset):
|
|||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save((examples, all_answers), cache_name)
|
||||
|
||||
super(SRL, self).__init__(examples, field, **kwargs)
|
||||
super(SRL, self).__init__(examples, **kwargs)
|
||||
self.all_answers = all_answers
|
||||
|
||||
|
||||
|
@ -909,22 +899,21 @@ class SRL(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='train', validation='dev', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
cls.cache_splits(path)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs)
|
||||
os.path.join(path, f'{validation}.jsonl'), one_answer=False, **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs)
|
||||
os.path.join(path, f'{test}.jsonl'), one_answer=False, **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -940,7 +929,7 @@ class WinogradSchema(CQA, data.Dataset):
|
|||
name = 'schema'
|
||||
dirname = ''
|
||||
|
||||
def __init__(self, path, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
|
@ -955,14 +944,14 @@ class WinogradSchema(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super(WinogradSchema, self).__init__(examples, field, **kwargs)
|
||||
super(WinogradSchema, self).__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def cache_splits(cls, path):
|
||||
|
@ -1021,22 +1010,21 @@ class WinogradSchema(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='train', validation='validation', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='validation', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
cls.cache_splits(path)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{validation}.jsonl'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{test}.jsonl'), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -1058,10 +1046,11 @@ class WOZ(CQA, data.Dataset):
|
|||
name = 'woz'
|
||||
dirname = ''
|
||||
|
||||
def __init__(self, path, field, subsample=None, description='woz.en', **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, description='woz.en', **kwargs):
|
||||
examples, all_answers = [], []
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description)
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
|
||||
str(subsample), description)
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
if os.path.exists(cache_name) and not skip_cache_bool:
|
||||
logger.info(f'Loading cached data from {cache_name}')
|
||||
|
@ -1075,7 +1064,7 @@ class WOZ(CQA, data.Dataset):
|
|||
all_answers.append((ex['lang_dialogue_turn'], answer))
|
||||
examples.append(Example.from_raw(make_example_id(self, woz_id),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -1083,7 +1072,7 @@ class WOZ(CQA, data.Dataset):
|
|||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save((examples, all_answers), cache_name)
|
||||
|
||||
super(WOZ, self).__init__(examples, field, **kwargs)
|
||||
super(WOZ, self).__init__(examples, **kwargs)
|
||||
self.all_answers = all_answers
|
||||
|
||||
@classmethod
|
||||
|
@ -1150,26 +1139,26 @@ class WOZ(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data', train='train', validation='validate', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='validate', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
cls.cache_splits(path)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{validation}.jsonl'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{test}.jsonl'), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
||||
class MultiNLI(CQA, data.Dataset):
|
||||
class MultiNLI(CQA):
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
|
@ -1180,9 +1169,10 @@ class MultiNLI(CQA, data.Dataset):
|
|||
name = 'multinli'
|
||||
dirname = 'multinli_1.0'
|
||||
|
||||
def __init__(self, path, field, subsample=None, description='multinli.in.out', **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, description='multinli.in.out', **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description)
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
|
||||
str(subsample), description)
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
if os.path.exists(cache_name) and not skip_cache_bool:
|
||||
logger.info(f'Loading cached data from {cache_name}')
|
||||
|
@ -1196,14 +1186,14 @@ class MultiNLI(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super(MultiNLI, self).__init__(examples, field, **kwargs)
|
||||
super(MultiNLI, self).__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def cache_splits(cls, path, train='multinli_1.0_train', validation='mulinli_1.0_dev_{}', test='test'):
|
||||
|
@ -1234,26 +1224,26 @@ class MultiNLI(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data', train='train', validation='validation', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='validation', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
cls.cache_splits(path)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{validation}.jsonl'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{test}.jsonl'), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
||||
class ZeroShotRE(CQA, data.Dataset):
|
||||
class ZeroShotRE(CQA):
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
|
@ -1264,9 +1254,10 @@ class ZeroShotRE(CQA, data.Dataset):
|
|||
name = 'zre'
|
||||
|
||||
|
||||
def __init__(self, path, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
|
||||
str(subsample))
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
if os.path.exists(cache_name) and not skip_cache_bool:
|
||||
logger.info(f'Loading cached data from {cache_name}')
|
||||
|
@ -1279,7 +1270,7 @@ class ZeroShotRE(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -1287,7 +1278,7 @@ class ZeroShotRE(CQA, data.Dataset):
|
|||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super().__init__(examples, field, **kwargs)
|
||||
super().__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def cache_splits(cls, path, train='train', validation='dev', test='test'):
|
||||
|
@ -1316,21 +1307,21 @@ class ZeroShotRE(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
cls.cache_splits(path)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{validation}.jsonl'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{test}.jsonl'), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
@ -1396,7 +1387,9 @@ class OntoNotesNER(CQA, data.Dataset):
|
|||
|
||||
return ' '.join(raw.split()).strip()
|
||||
|
||||
def __init__(self, path, field, one_answer=True, subsample=None, path_to_files='.data/ontonotes-release-5.0/data/files', subtask='all', nones=True, **kwargs):
|
||||
def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False,
|
||||
path_to_files='.data/ontonotes-release-5.0/data/files',
|
||||
subtask='all', nones=True, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), subtask, str(nones))
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
|
@ -1416,7 +1409,7 @@ class OntoNotesNER(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -1424,7 +1417,7 @@ class OntoNotesNER(CQA, data.Dataset):
|
|||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super(OntoNotesNER, self).__init__(examples, field, **kwargs)
|
||||
super(OntoNotesNER, self).__init__(examples, **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
|
@ -1554,8 +1547,7 @@ class OntoNotesNER(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data',
|
||||
train='train', validation='development', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='development', test='test', **kwargs):
|
||||
path_to_files = os.path.join(root, 'ontonotes-release-5.0', 'data', 'files')
|
||||
assert os.path.exists(path_to_files)
|
||||
path = cls.download(root)
|
||||
|
@ -1564,18 +1556,18 @@ class OntoNotesNER(CQA, data.Dataset):
|
|||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs)
|
||||
os.path.join(path, f'{validation}.jsonl'), one_answer=False, **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs)
|
||||
os.path.join(path, f'{test}.jsonl'), one_answer=False, **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
class SNLI(CQA, data.Dataset):
|
||||
class SNLI(CQA):
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
|
@ -1586,9 +1578,10 @@ class SNLI(CQA, data.Dataset):
|
|||
name = 'snli'
|
||||
|
||||
|
||||
def __init__(self, path, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
|
||||
str(subsample))
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
if os.path.exists(cache_name) and not skip_cache_bool:
|
||||
logger.info(f'Loading cached data from {cache_name}')
|
||||
|
@ -1602,7 +1595,7 @@ class SNLI(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
|
@ -1610,7 +1603,7 @@ class SNLI(CQA, data.Dataset):
|
|||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super().__init__(examples, field, **kwargs)
|
||||
super().__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def cache_splits(cls, path, train='train', validation='dev', test='test'):
|
||||
|
@ -1632,34 +1625,36 @@ class SNLI(CQA, data.Dataset):
|
|||
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs):
|
||||
def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs):
|
||||
path = cls.download(root)
|
||||
cls.cache_splits(path)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, f'{train}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{train}.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, f'{validation}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{validation}.jsonl'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, f'{test}.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, f'{test}.jsonl'), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
||||
|
||||
class JSON(CQA, data.Dataset):
|
||||
class JSON(CQA):
|
||||
name = 'json'
|
||||
|
||||
@staticmethod
|
||||
def sort_key(ex):
|
||||
return data.interleave_keys(len(ex.context), len(ex.answer))
|
||||
|
||||
def __init__(self, path, field, subsample=None, **kwargs):
|
||||
def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs):
|
||||
cached_path = kwargs.pop('cached_path')
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample))
|
||||
cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path),
|
||||
str(subsample))
|
||||
|
||||
examples = []
|
||||
skip_cache_bool = kwargs.pop('skip_cache_bool')
|
||||
|
@ -1674,30 +1669,29 @@ class JSON(CQA, data.Dataset):
|
|||
context, question, answer = ex['context'], ex['question'], ex['answer']
|
||||
examples.append(Example.from_raw(make_example_id(self, len(examples)),
|
||||
context, question, answer,
|
||||
tokenize=field.tokenize, lower=field.lower))
|
||||
tokenize=tokenize, lower=lower))
|
||||
if subsample is not None and len(examples) >= subsample:
|
||||
break
|
||||
os.makedirs(os.path.dirname(cache_name), exist_ok=True)
|
||||
logger.info(f'Caching data to {cache_name}')
|
||||
torch.save(examples, cache_name)
|
||||
|
||||
super(JSON, self).__init__(examples, field, **kwargs)
|
||||
super(JSON, self).__init__(examples, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def splits(cls, fields, name, root='.data',
|
||||
train='train', validation='val', test='test', **kwargs):
|
||||
def splits(cls, name, root='.data', train='train', validation='val', test='test', **kwargs):
|
||||
path = os.path.join(root, name)
|
||||
|
||||
aux_data = None
|
||||
if kwargs.get('curriculum', False):
|
||||
kwargs.pop('curriculum')
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs)
|
||||
aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs)
|
||||
|
||||
train_data = None if train is None else cls(
|
||||
os.path.join(path, 'train.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, 'train.jsonl'), **kwargs)
|
||||
validation_data = None if validation is None else cls(
|
||||
os.path.join(path, 'val.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, 'val.jsonl'), **kwargs)
|
||||
test_data = None if test is None else cls(
|
||||
os.path.join(path, 'test.jsonl'), fields, **kwargs)
|
||||
os.path.join(path, 'test.jsonl'), **kwargs)
|
||||
return tuple(d for d in (train_data, validation_data, test_data, aux_data)
|
||||
if d is not None)
|
||||
|
|
|
@@ -18,19 +18,14 @@ class Dataset(torch.utils.data.Dataset):
        fields: A dictionary containing the name of each column together with
            its corresponding Field object. Two columns with the same Field
            object will share a vocabulary.
        fields (dict[str, Field]): Contains the name of each column or field, together
            with the corresponding Field object. Two fields with the same Field object
            will have a shared vocabulary.
    """
    sort_key = None

    def __init__(self, examples, fields, filter_pred=None):
    def __init__(self, examples, filter_pred=None, **kwargs):
        """Create a dataset from a list of Examples and Fields.

        Arguments:
            examples: List of Examples.
            fields (List(tuple(str, Field))): The Fields to use in this tuple. The
                string is a field name, and the Field is the associated field.
            filter_pred (callable or None): Use only examples for which
                filter_pred(example) is True, or use all examples if None.
                Default is None.

@@ -41,16 +36,13 @@ class Dataset(torch.utils.data.Dataset):
        if make_list:
            examples = list(examples)
        self.examples = examples
        self.fields = dict(fields)

    @classmethod
    def splits(cls, path=None, root='.data', train=None, validation=None,
    def splits(cls, root='.data', train=None, validation=None,
               test=None, **kwargs):
        """Create Dataset objects for multiple splits of a dataset.

        Arguments:
            path (str): Common prefix of the splits' file paths, or None to use
                the result of cls.download(root).
            root (str): Root dataset storage directory. Default is '.data'.
            train (str): Suffix to add to path for the train set, or None for no
                train set. Default is None.

@@ -65,8 +57,7 @@ class Dataset(torch.utils.data.Dataset):
            split_datasets (tuple(Dataset)): Datasets for train, validation, and
                test splits in that order, if provided.
        """
        if path is None:
            path = cls.download(root)
        path = cls.download(root)
        train_data = None if train is None else cls(
            os.path.join(path, train), **kwargs)
        val_data = None if validation is None else cls(

@@ -89,11 +80,6 @@ class Dataset(torch.utils.data.Dataset):
        for x in self.examples:
            yield x

    def __getattr__(self, attr):
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)

    @classmethod
    def download(cls, root, check=None):
        """Download and unzip an online archive (.zip, .gz, or .tgz).
@@ -244,7 +244,7 @@ class Field(RawField):
            return (padded, lengths)
        return padded

    def build_vocab(self, *args, **kwargs):
    def build_vocab(self, field_names, *args, **kwargs):
        """Construct the Vocab object for this field from one or more datasets.

        Arguments:

@@ -259,11 +259,7 @@ class Field(RawField):
        counter = Counter()
        sources = []
        for arg in args:
            if hasattr(arg, 'fields'):
                sources += [getattr(arg, name) for name, field in
                            arg.fields.items() if field is self]
            else:
                sources.append(arg)
            sources += [getattr(ex, name) for name in field_names for ex in arg]
        for data in sources:
            for x in data:
                if not self.sequential:
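For orientation, a minimal sketch (assumptions: each dataset is simply an iterable of Example objects, and field_names lists the attributes that feed the vocabulary) of the source gathering that the patched build_vocab switches to above:

    from collections import Counter

    def gather_counts(field_names, *datasets):
        # Pull token lists straight off each example's attributes instead of
        # asking the dataset which of its columns share this Field.
        counter = Counter()
        sources = []
        for dataset in datasets:
            sources += [getattr(ex, name) for name in field_names for ex in dataset]
        for tokens in sources:
            counter.update(tokens)
        return counter

Given examples with context/question/answer list attributes, gather_counts(['context', 'question', 'answer'], train_set, val_set) would accumulate roughly the token counts the patched Field.build_vocab builds its Vocab from.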
@@ -34,15 +34,11 @@ import math
import time
import sys
from copy import deepcopy

import logging
from pprint import pformat
from logging import handlers
import numpy as np
from .utils.model_utils import init_model

import torch

from tensorboardX import SummaryWriter

from . import arguments

@@ -53,6 +49,8 @@ from .utils.saver import Saver
from .utils.embeddings import load_embeddings
from .text.data import ReversibleField
from .data.numericalizer import DecoderVocabulary
from .data.example import Example
from .utils.model_utils import init_model


def initialize_logger(args, rank='main'):

@@ -78,12 +76,9 @@ def log(rank='main'):


def prepare_data(args, field, logger):

    if field is None:
        logger.info(f'Constructing field')
        FIELD = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)
    else:
        FIELD = field
    field = ReversibleField(batch_first=True, init_token='<init>', eos_token='<eos>', lower=args.lower, include_lengths=True)

    train_sets, val_sets, aux_sets, vocab_sets = [], [], [], []
    for task in args.train_tasks:

@@ -97,7 +92,7 @@ def prepare_data(args, field, logger):
        kwargs['cached_path'] = args.cached

        logger.info(f'Adding {task.name} to training datasets')
        split = task.get_splits(FIELD, args.data, **kwargs)
        split = task.get_splits(args.data, tokenize=field.tokenize, lower=args.lower, **kwargs)
        if args.use_curriculum:
            assert len(split) == 2
            aux_sets.append(split[1])

@@ -118,7 +113,7 @@ def prepare_data(args, field, logger):
        kwargs['cached_path'] = args.cached

        logger.info(f'Adding {task.name} to validation datasets')
        split = task.get_splits(FIELD, args.data, **kwargs)
        split = task.get_splits(args.data, tokenize=field.tokenize, lower=args.lower, **kwargs)
        assert len(split) == 1
        logger.info(f'{task.name} has {len(split[0])} validation examples')
        val_sets.append(split[0])

@@ -129,23 +124,23 @@ def prepare_data(args, field, logger):
    vectors = load_embeddings(args, logger)
    vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets
    logger.info(f'Building vocabulary')
    FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)
    field.build_vocab(Example.vocab_fields, *vocab_sets, max_size=args.max_effective_vocab, vectors=vectors)

    FIELD.decoder_vocab = DecoderVocabulary(FIELD.vocab.itos[:args.max_generative_vocab], FIELD.vocab)
    field.decoder_vocab = DecoderVocabulary(field.vocab.itos[:args.max_generative_vocab], field.vocab)

    logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens')
    logger.info(f'Vocabulary has {len(field.vocab)} tokens')
    logger.debug(f'The first 200 tokens:')
    logger.debug(FIELD.vocab.itos[:200])
    logger.debug(field.vocab.itos[:200])

    if args.use_curriculum:
        logger.info('Preprocessing auxiliary data for curriculum')
        preprocess_examples(args, args.train_tasks, aux_sets, FIELD, logger, train=True)
        preprocess_examples(args, args.train_tasks, aux_sets, field, logger, train=True)
    logger.info('Preprocessing training data')
    preprocess_examples(args, args.train_tasks, train_sets, FIELD, logger, train=True)
    preprocess_examples(args, args.train_tasks, train_sets, field, logger, train=True)
    logger.info('Preprocessing validation data')
    preprocess_examples(args, args.val_tasks, val_sets, FIELD, logger, train=args.val_filter)
    preprocess_examples(args, args.val_tasks, val_sets, field, logger, train=args.val_filter)

    return FIELD, train_sets, val_sets, aux_sets
    return field, train_sets, val_sets, aux_sets


def get_learning_rate(i, args):