From ea83b75089f101764f4ef702e30a681e8b5674cf Mon Sep 17 00:00:00 2001 From: Giovanni Campagna Date: Sat, 21 Dec 2019 07:36:28 -0800 Subject: [PATCH] Remove all usage of Field from tasks and datasets --- decanlp/data/example.py | 2 + decanlp/predict.py | 11 +- decanlp/tasks/almond/__init__.py | 21 +-- decanlp/tasks/generic.py | 80 +++++---- decanlp/tasks/generic_dataset.py | 290 +++++++++++++++---------------- decanlp/text/data/dataset.py | 20 +-- decanlp/text/data/field.py | 8 +- decanlp/train.py | 31 ++-- 8 files changed, 220 insertions(+), 243 deletions(-) diff --git a/decanlp/data/example.py b/decanlp/data/example.py index 19bed54b..f7aee4ab 100644 --- a/decanlp/data/example.py +++ b/decanlp/data/example.py @@ -41,6 +41,8 @@ class Example(NamedTuple): question : List[str] answer : List[str] + vocab_fields = ['context', 'question', 'answer'] + @staticmethod def from_raw(example_id : str, context : str, question : str, answer : str, tokenize, lower=False): args = [[example_id]] diff --git a/decanlp/predict.py b/decanlp/predict.py index 56a4388c..40b7d414 100644 --- a/decanlp/predict.py +++ b/decanlp/predict.py @@ -43,11 +43,12 @@ from .metrics import compute_metrics from .utils.embeddings import load_embeddings from .tasks.registry import get_tasks from . import models -from .text.data import Iterator, ReversibleField +from .data.example import Example +from .text.data import ReversibleField logger = logging.getLogger(__name__) -def get_all_splits(args, new_vocab): +def get_all_splits(args, new_field): splits = [] for task in args.tasks: logger.info(f'Loading {task}') @@ -62,8 +63,8 @@ def get_all_splits(args, new_vocab): kwargs['skip_cache_bool'] = args.skip_cache_bool kwargs['cached_path'] = args.cached kwargs['subsample'] = args.subsample - s = task.get_splits(new_vocab, root=args.data, **kwargs)[0] - preprocess_examples(args, [task], [s], new_vocab, train=False) + s = task.get_splits(root=args.data, tokenize=new_field.tokenize, lower=args.lower, **kwargs)[0] + preprocess_examples(args, [task], [s], new_field, train=False) splits.append(s) return splits @@ -71,7 +72,7 @@ def get_all_splits(args, new_vocab): def prepare_data(args, FIELD): new_vocab = ReversibleField(batch_first=True, init_token='', eos_token='', lower=args.lower, include_lengths=True) splits = get_all_splits(args, new_vocab) - new_vocab.build_vocab(*splits) + new_vocab.build_vocab(Example.vocab_fields, *splits) logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens from training') args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab) FIELD.append_vocab(new_vocab) diff --git a/decanlp/tasks/almond/__init__.py b/decanlp/tasks/almond/__init__.py index c6b10e17..0ada1e77 100644 --- a/decanlp/tasks/almond/__init__.py +++ b/decanlp/tasks/almond/__init__.py @@ -46,7 +46,7 @@ class AlmondDataset(generic_dataset.CQA): base_url = None - def __init__(self, path, field, tokenize, contextual=False, reverse_task=False, subsample=None, **kwargs): + def __init__(self, path, contextual=False, reverse_task=False, subsample=None, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) @@ -100,23 +100,21 @@ class AlmondDataset(generic_dataset.CQA): examples.append(generic_dataset.Example.from_raw('almond/' + _id, context, question, answer, - tokenize=str.split, lower=field.lower)) + tokenize=str.split, lower=False)) if len(examples) >= max_examples: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super().__init__(examples, field, **kwargs) + super().__init__(examples, **kwargs) @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) @classmethod - def splits(cls, fields, root='.data', - train='train', validation='eval', - test='test', contextual=False, **kwargs): + def splits(cls, root='.data', train='train', validation='eval', test='test', contextual=False, **kwargs): """Create dataset objects for splits of the ThingTalk dataset. Arguments: @@ -138,14 +136,14 @@ class AlmondDataset(generic_dataset.CQA): aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux' + '.tsv'), fields, contextual=contextual, **kwargs) + aux_data = cls(os.path.join(path, 'aux' + '.tsv'), contextual=contextual, **kwargs) train_data = None if train is None else cls( - os.path.join(path, train + '.tsv'), fields, contextual=contextual, **kwargs) + os.path.join(path, train + '.tsv'), contextual=contextual, **kwargs) val_data = None if validation is None else cls( - os.path.join(path, validation + '.tsv'), fields, contextual=contextual, **kwargs) + os.path.join(path, validation + '.tsv'), contextual=contextual, **kwargs) test_data = None if test is None else cls( - os.path.join(path, test + '.tsv'), fields, contextual=contextual, **kwargs) + os.path.join(path, test + '.tsv'), contextual=contextual, **kwargs) return tuple(d for d in (train_data, val_data, test_data, aux_data) if d is not None) @@ -180,8 +178,7 @@ class Almond(BaseAlmondTask): i.e. natural language to formal language (ThingTalk) mapping""" def get_splits(self, field, root, **kwargs): - return AlmondDataset.splits( - fields=field, root=root, tokenize=self.tokenize, **kwargs) + return AlmondDataset.splits(root=root, lower=field.lower, **kwargs) @register_task('contextual_almond') diff --git a/decanlp/tasks/generic.py b/decanlp/tasks/generic.py index a9c9ecfe..01248da8 100644 --- a/decanlp/tasks/generic.py +++ b/decanlp/tasks/generic.py @@ -38,10 +38,11 @@ class Multi30K(BaseTask): def metrics(self): return ['bleu', 'em', 'nem', 'nf1'] - def get_splits(self, field, root, **kwargs): + def get_splits(self, root, **kwargs): src, trg = ['.' + x for x in self.name.split('.')[1:]] return generic_dataset.Multi30k.splits(exts=(src, trg), - fields=field, root=root, **kwargs) + root=root, + **kwargs) @register_task('iwslt') @@ -50,10 +51,11 @@ class IWSLT(BaseTask): def metrics(self): return ['bleu', 'em', 'nem', 'nf1'] - def get_splits(self, field, root, **kwargs): + def get_splits(self, root, **kwargs): src, trg = ['.' + x for x in self.name.split('.')[1:]] return generic_dataset.IWSLT.splits(exts=(src, trg), - fields=field, root=root, **kwargs) + root=root, + **kwargs) @register_task('squad') @@ -62,9 +64,10 @@ class SQuAD(BaseTask): def metrics(self): return ['nf1', 'em', 'nem'] - def get_splits(self, field, root, **kwargs): - return generic_dataset.SQuAD.splits( - fields=field, root=root, description=self.name, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.SQuAD.splits(root=root, + description=self.name, + **kwargs) @register_task('wikisql') @@ -73,19 +76,22 @@ class WikiSQL(BaseTask): def metrics(self): return ['lfem', 'em', 'nem', 'nf1'] - def get_splits(self, field, root, **kwargs): + def get_splits(self, root, **kwargs): return generic_dataset.WikiSQL.splits( - fields=field, root=root, query_as_question='query_as_question' in self.name, **kwargs) + root=root, + query_as_question='query_as_question' in self.name, + **kwargs) @register_task('ontonotes') class OntoNotesNER(BaseTask): - def get_splits(self, field, root, **kwargs): + def get_splits(self, root, **kwargs): split_task = self.name.split('.') _, _, subtask, nones, counting = split_task return generic_dataset.OntoNotesNER.splits( subtask=subtask, nones=True if nones == 'nones' else False, - fields=field, root=root, **kwargs) + root=root, + **kwargs) @register_task('woz') @@ -94,16 +100,19 @@ class WoZ(BaseTask): def metrics(self): return ['joint_goal_em', 'turn_request_em', 'turn_goal_em', 'avg_dialogue', 'em', 'nem', 'nf1'] - def get_splits(self, field, root, **kwargs): + def get_splits(self, root, **kwargs): return generic_dataset.WOZ.splits(description=self.name, - fields=field, root=root, **kwargs) + root=root, + **kwargs) @register_task('multinli') class MultiNLI(BaseTask): - def get_splits(self, field, root, **kwargs): + def get_splits(self, root, **kwargs): return generic_dataset.MultiNLI.splits(description=self.name, - fields=field, root=root, **kwargs) + root=root, + **kwargs) + @register_task('srl') class SRL(BaseTask): @@ -111,20 +120,20 @@ class SRL(BaseTask): def metrics(self): return ['nf1', 'em', 'nem'] - def get_splits(self, field, root, **kwargs): - return generic_dataset.SRL.splits(fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.SRL.splits(root=root, **kwargs) @register_task('snli') class SNLI(BaseTask): - def get_splits(self, field, root, **kwargs): - return generic_dataset.SNLI.splits(fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.SNLI.splits(root=root, **kwargs) @register_task('schema') class WinogradSchema(BaseTask): - def get_splits(self, field, root, **kwargs): - return generic_dataset.WinogradSchema.splits(fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.WinogradSchema.splits(root=root, **kwargs) class BaseSummarizationTask(BaseTask): @@ -142,23 +151,21 @@ class BaseSummarizationTask(BaseTask): @register_task('cnn') class CNN(BaseSummarizationTask): - def get_splits(self, field, root, **kwargs): - return generic_dataset.CNN.splits(fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.CNN.splits(root=root, **kwargs) @register_task('dailymail') class DailyMail(BaseSummarizationTask): - def get_splits(self, field, root, **kwargs): - return generic_dataset.DailyMail.splits(fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.DailyMail.splits(root=root, **kwargs) @register_task('cnn_dailymail') class CNNDailyMail(BaseSummarizationTask): - def get_splits(self, field, root, **kwargs): - split_cnn = generic_dataset.CNN.splits( - fields=field, root=root, **kwargs) - split_dm = generic_dataset.DailyMail.splits( - fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + split_cnn = generic_dataset.CNN.splits(root=root, **kwargs) + split_dm = generic_dataset.DailyMail.splits(root=root, **kwargs) for scnn, sdm in zip(split_cnn, split_dm): scnn.examples.extend(sdm) return split_cnn @@ -166,9 +173,8 @@ class CNNDailyMail(BaseSummarizationTask): @register_task('sst') class SST(BaseTask): - def get_splits(self, field, root, **kwargs): - return generic_dataset.SST.splits( - fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.SST.splits(root=root, **kwargs) @register_task('imdb') @@ -176,9 +182,9 @@ class IMDB(BaseTask): def preprocess_example(self, ex, train=False, max_context_length=None): return ex._replace(context=ex.context[:max_context_length]) - def get_splits(self, field, root, **kwargs): + def get_splits(self, root, **kwargs): kwargs['validation'] = None - return generic_dataset.IMDb.splits(fields=field, root=root, **kwargs) + return generic_dataset.IMDb.splits(root=root, **kwargs) @register_task('zre') @@ -187,5 +193,5 @@ class ZRE(BaseTask): def metrics(self): return ['corpus_f1', 'precision', 'recall', 'em', 'nem', 'nf1'] - def get_splits(self, field, root, **kwargs): - return generic_dataset.ZeroShotRE.splits(fields=field, root=root, **kwargs) + def get_splits(self, root, **kwargs): + return generic_dataset.ZeroShotRE.splits(root=root, **kwargs) diff --git a/decanlp/tasks/generic_dataset.py b/decanlp/tasks/generic_dataset.py index e8b1a119..10b47d35 100644 --- a/decanlp/tasks/generic_dataset.py +++ b/decanlp/tasks/generic_dataset.py @@ -51,10 +51,6 @@ def make_example_id(dataset, example_id): class CQA(data.Dataset): - def __init__(self, examples, field, **kwargs): - fields = [(x, field) for x in Example._fields] - super().__init__(examples, fields, **kwargs) - @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) @@ -69,7 +65,7 @@ class IMDb(CQA): def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) - def __init__(self, path, field, subsample=None, **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs): examples = [] labels = {'neg': 'negative', 'pos': 'positive'} question = 'Is this review negative or positive?' @@ -88,30 +84,29 @@ class IMDb(CQA): answer = labels[label] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) > subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super().__init__(examples, field, **kwargs) + super().__init__(examples, **kwargs) @classmethod - def splits(cls, fields, root='.data', - train='train', validation=None, test='test', **kwargs): + def splits(cls, root='.data', train='train', validation=None, test='test', **kwargs): assert validation is None path = cls.download(root) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}'), fields, **kwargs) + os.path.join(path, f'{train}'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}'), fields, **kwargs) + os.path.join(path, f'{test}'), **kwargs) return tuple(d for d in (train_data, test_data, aux_data) if d is not None) @@ -128,7 +123,7 @@ class SST(CQA): def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) - def __init__(self, path, field, subsample=None, **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) @@ -149,7 +144,7 @@ class SST(CQA): answer = labels[int(parsed[0])] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) > subsample: break @@ -159,25 +154,24 @@ class SST(CQA): torch.save(examples, cache_name) self.examples = examples - super().__init__(examples, field, **kwargs) + super().__init__(examples, **kwargs) @classmethod - def splits(cls, fields, root='.data', - train='train', validation='dev', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) postfix = f'_binary_sent.csv' aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, f'aux{postfix}'), fields, **kwargs) + aux_data = cls(os.path.join(path, f'aux{postfix}'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}{postfix}'), fields, **kwargs) + os.path.join(path, f'{train}{postfix}'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}{postfix}'), fields, **kwargs) + os.path.join(path, f'{validation}{postfix}'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}{postfix}'), fields, **kwargs) + os.path.join(path, f'{test}{postfix}'), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) @@ -188,7 +182,7 @@ class TranslationDataset(CQA): def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) - def __init__(self, path, exts, field, subsample=None, **kwargs): + def __init__(self, path, exts, subsample=None, tokenize=None, lower=False, **kwargs): """Create a TranslationDataset given paths and fields. Arguments: @@ -220,7 +214,7 @@ class TranslationDataset(CQA): answer = trg_line examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break @@ -228,11 +222,10 @@ class TranslationDataset(CQA): os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super().__init__(examples, field, **kwargs) + super().__init__(examples, **kwargs) @classmethod - def splits(cls, exts, fields, root='.data', - train='train', validation='val', test='test', **kwargs): + def splits(cls, exts, root='.data', train='train', validation='val', test='test', **kwargs): """Create dataset objects for splits of a TranslationDataset. Arguments: @@ -252,14 +245,14 @@ class TranslationDataset(CQA): aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux'), exts, fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux'), exts, **kwargs) train_data = None if train is None else cls( - os.path.join(path, train), exts, fields, **kwargs) + os.path.join(path, train), exts, **kwargs) val_data = None if validation is None else cls( - os.path.join(path, validation), exts, fields, **kwargs) + os.path.join(path, validation), exts, **kwargs) test_data = None if test is None else cls( - os.path.join(path, test), exts, fields, **kwargs) + os.path.join(path, test), exts, **kwargs) return tuple(d for d in (train_data, val_data, test_data, aux_data) if d is not None) @@ -279,7 +272,7 @@ class IWSLT(TranslationDataset, CQA): base_dirname = '{}-{}' @classmethod - def splits(cls, exts, fields, root='.data', + def splits(cls, exts, root='.data', train='train', validation='IWSLT16.TED.tst2013', test='IWSLT16.TED.tst2014', **kwargs): """Create dataset objects for splits of the IWSLT dataset. @@ -305,7 +298,7 @@ class IWSLT(TranslationDataset, CQA): if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux = '.'.join(['aux', cls.dirname]) - aux_data = cls(os.path.join(path, aux), exts, fields, **kwargs) + aux_data = cls(os.path.join(path, aux), exts, **kwargs) if train is not None: train = '.'.join([train, cls.dirname]) @@ -318,11 +311,11 @@ class IWSLT(TranslationDataset, CQA): cls.clean(path) train_data = None if train is None else cls( - os.path.join(path, train), exts, fields, **kwargs) + os.path.join(path, train), exts, **kwargs) val_data = None if validation is None else cls( - os.path.join(path, validation), exts, fields, **kwargs) + os.path.join(path, validation), exts, **kwargs) test_data = None if test is None else cls( - os.path.join(path, test), exts, fields, **kwargs) + os.path.join(path, test), exts, **kwargs) return tuple(d for d in (train_data, val_data, test_data, aux_data) if d is not None) @@ -349,7 +342,7 @@ class IWSLT(TranslationDataset, CQA): fd_txt.write(l.strip() + '\n') -class SQuAD(CQA, data.Dataset): +class SQuAD(CQA): @staticmethod def sort_key(ex): @@ -362,7 +355,7 @@ class SQuAD(CQA, data.Dataset): name = 'squad' dirname = '' - def __init__(self, path, field, subsample=None, **kwargs): + def __init__(self, path, subsample=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) @@ -389,7 +382,7 @@ class SQuAD(CQA, data.Dataset): context = ' '.join(context.split()) ex = Example.from_raw(make_example_id(self, qa['id']), context, question, answer, - tokenize=str.split, lower=field.lower) + tokenize=str.split, lower=lower) else: answer = qa['answers'][0]['text'] all_answers.append([a['text'] for a in qa['answers']]) @@ -435,8 +428,7 @@ class SQuAD(CQA, data.Dataset): indexed_answer = tagged_context[context_spans[0]:context_spans[-1]+1] if len(indexed_answer) != len(tokenized_answer): import pdb; pdb.set_trace() - if field.eos_token is not None: - context_spans += [len(tagged_context)] + context_spans += [len(tagged_context)] for context_idx, answer_word in zip(context_spans, ex.answer): if context_idx == len(tagged_context): continue @@ -445,7 +437,7 @@ class SQuAD(CQA, data.Dataset): ex = Example.from_raw(make_example_id(self, qa['id']), ' '.join(tagged_context), question, ' '.join(tokenized_answer), - tokenize=str.split, lower=field.lower) + tokenize=str.split, lower=lower) examples.append(ex) if subsample is not None and len(examples) > subsample: @@ -459,14 +451,13 @@ class SQuAD(CQA, data.Dataset): logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers, q_ids), cache_name) - super(SQuAD, self).__init__(examples, field, **kwargs) + super(SQuAD, self).__init__(examples, **kwargs) self.all_answers = all_answers self.q_ids = q_ids @classmethod - def splits(cls, fields, root='.data', description='squad1.1', - train='train', validation='dev', test=None, **kwargs): + def splits(cls, root='.data', description='squad1.1', train='train', validation='dev', test=None, **kwargs): """Create dataset objects for splits of the SQuAD dataset. Arguments: root: directory containing SQuAD data @@ -484,15 +475,15 @@ class SQuAD(CQA, data.Dataset): if kwargs.get('curriculum', False): kwargs.pop('curriculum') aux = '-'.join(['aux', extension]) - aux_data = cls(os.path.join(path, aux), fields, **kwargs) + aux_data = cls(os.path.join(path, aux), **kwargs) train = '-'.join([train, extension]) if train is not None else None validation = '-'.join([validation, extension]) if validation is not None else None train_data = None if train is None else cls( - os.path.join(path, train), fields, **kwargs) + os.path.join(path, train), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, validation), fields, **kwargs) + os.path.join(path, validation), **kwargs) return tuple(d for d in (train_data, validation_data, aux_data) if d is not None) @@ -510,13 +501,13 @@ def fix_missing_period(line): return line + "." -class Summarization(CQA, data.Dataset): +class Summarization(CQA): @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) - def __init__(self, path, field, one_answer=True, subsample=None, **kwargs): + def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) @@ -533,14 +524,14 @@ class Summarization(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super(Summarization, self).__init__(examples, field, **kwargs) + super(Summarization, self).__init__(examples, **kwargs) @classmethod def cache_splits(cls, path): @@ -593,22 +584,21 @@ class Summarization(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', - train='training', validation='validation', test='test', **kwargs): + def splits(cls, root='.data', train='training', validation='validation', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'auxiliary.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, 'training.jsonl'), fields, **kwargs) + os.path.join(path, 'training.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, 'validation.jsonl'), fields, one_answer=False, **kwargs) + os.path.join(path, 'validation.jsonl'), one_answer=False, **kwargs) test_data = None if test is None else cls( - os.path.join(path, 'test.jsonl'), fields, one_answer=False, **kwargs) + os.path.join(path, 'test.jsonl'), one_answer=False, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) @@ -659,7 +649,7 @@ class Query: return cls(sel_index=d['sel'], agg_index=d['agg'], columns=t, conditions=d['conds']) -class WikiSQL(CQA, data.Dataset): +class WikiSQL(CQA): @staticmethod def sort_key(ex): @@ -669,7 +659,7 @@ class WikiSQL(CQA, data.Dataset): name = 'wikisql' dirname = 'data' - def __init__(self, path, field, query_as_question=False, subsample=None, **kwargs): + def __init__(self, path, query_as_question=False, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', 'query_as_question' if query_as_question else 'query_as_context', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') @@ -705,7 +695,7 @@ class WikiSQL(CQA, data.Dataset): context += f'-- {human_query}' examples.append(Example.from_raw(make_example_id(self, idx), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) all_answers.append({'sql': sql, 'header': header, 'answer': answer, 'table': table}) if subsample is not None and len(examples) > subsample: break @@ -714,13 +704,12 @@ class WikiSQL(CQA, data.Dataset): logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers), cache_name) - super(WikiSQL, self).__init__(examples, field, **kwargs) + super(WikiSQL, self).__init__(examples, **kwargs) self.all_answers = all_answers @classmethod - def splits(cls, fields, root='.data', - train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs): + def splits(cls, root='.data', train='train.jsonl', validation='dev.jsonl', test='test.jsonl', **kwargs): """Create dataset objects for splits of the SQuAD dataset. Arguments: root: directory containing SQuAD data @@ -735,19 +724,19 @@ class WikiSQL(CQA, data.Dataset): aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, train), fields, **kwargs) + os.path.join(path, train), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, validation), fields, **kwargs) + os.path.join(path, validation), **kwargs) test_data = None if test is None else cls( - os.path.join(path, test), fields, **kwargs) + os.path.join(path, test), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) -class SRL(CQA, data.Dataset): +class SRL(CQA): @staticmethod def sort_key(ex): @@ -762,9 +751,10 @@ class SRL(CQA, data.Dataset): @classmethod def clean(cls, s): - closing_punctuation = set([ ' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm", " 'd", " 're"]) - opening_punctuation = set(['( ', '$ ']) - both_sides = set([' - ']) + closing_punctuation = {' .', ' ,', ' ;', ' !', ' ?', ' :', ' )', " 'll", " n't ", " %", " 't", " 's", " 'm", + " 'd", " 're"} + opening_punctuation = {'( ', '$ '} + both_sides = {' - '} s = ' '.join(s.split()).strip() s = s.replace('-LRB-', '(') s = s.replace('-RRB-', ')') @@ -787,7 +777,7 @@ class SRL(CQA, data.Dataset): s = s.replace(" '", '') return ' '.join(s.split()).strip() - def __init__(self, path, field, one_answer=True, subsample=None, **kwargs): + def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) @@ -805,7 +795,7 @@ class SRL(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(all_answers)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) all_answers.append(aa) if subsample is not None and len(examples) >= subsample: break @@ -813,7 +803,7 @@ class SRL(CQA, data.Dataset): logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers), cache_name) - super(SRL, self).__init__(examples, field, **kwargs) + super(SRL, self).__init__(examples, **kwargs) self.all_answers = all_answers @@ -909,22 +899,21 @@ class SRL(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', - train='train', validation='dev', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}.jsonl'), fields, **kwargs) + os.path.join(path, f'{train}.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs) + os.path.join(path, f'{validation}.jsonl'), one_answer=False, **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs) + os.path.join(path, f'{test}.jsonl'), one_answer=False, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) @@ -940,7 +929,7 @@ class WinogradSchema(CQA, data.Dataset): name = 'schema' dirname = '' - def __init__(self, path, field, subsample=None, **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') @@ -955,14 +944,14 @@ class WinogradSchema(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super(WinogradSchema, self).__init__(examples, field, **kwargs) + super(WinogradSchema, self).__init__(examples, **kwargs) @classmethod def cache_splits(cls, path): @@ -1021,22 +1010,21 @@ class WinogradSchema(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', - train='train', validation='validation', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='validation', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}.jsonl'), fields, **kwargs) + os.path.join(path, f'{train}.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) + os.path.join(path, f'{validation}.jsonl'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}.jsonl'), fields, **kwargs) + os.path.join(path, f'{test}.jsonl'), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) @@ -1058,10 +1046,11 @@ class WOZ(CQA, data.Dataset): name = 'woz' dirname = '' - def __init__(self, path, field, subsample=None, description='woz.en', **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, description='woz.en', **kwargs): examples, all_answers = [], [] cached_path = kwargs.pop('cached_path') - cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description) + cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), + str(subsample), description) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') @@ -1075,7 +1064,7 @@ class WOZ(CQA, data.Dataset): all_answers.append((ex['lang_dialogue_turn'], answer)) examples.append(Example.from_raw(make_example_id(self, woz_id), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break @@ -1083,7 +1072,7 @@ class WOZ(CQA, data.Dataset): logger.info(f'Caching data to {cache_name}') torch.save((examples, all_answers), cache_name) - super(WOZ, self).__init__(examples, field, **kwargs) + super(WOZ, self).__init__(examples, **kwargs) self.all_answers = all_answers @classmethod @@ -1150,26 +1139,26 @@ class WOZ(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', train='train', validation='validate', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='validate', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}.jsonl'), fields, **kwargs) + os.path.join(path, f'{train}.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) + os.path.join(path, f'{validation}.jsonl'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}.jsonl'), fields, **kwargs) + os.path.join(path, f'{test}.jsonl'), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) -class MultiNLI(CQA, data.Dataset): +class MultiNLI(CQA): @staticmethod def sort_key(ex): @@ -1180,9 +1169,10 @@ class MultiNLI(CQA, data.Dataset): name = 'multinli' dirname = 'multinli_1.0' - def __init__(self, path, field, subsample=None, description='multinli.in.out', **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, description='multinli.in.out', **kwargs): cached_path = kwargs.pop('cached_path') - cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), description) + cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), + str(subsample), description) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') @@ -1196,14 +1186,14 @@ class MultiNLI(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super(MultiNLI, self).__init__(examples, field, **kwargs) + super(MultiNLI, self).__init__(examples, **kwargs) @classmethod def cache_splits(cls, path, train='multinli_1.0_train', validation='mulinli_1.0_dev_{}', test='test'): @@ -1234,26 +1224,26 @@ class MultiNLI(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', train='train', validation='validation', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='validation', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}.jsonl'), fields, **kwargs) + os.path.join(path, f'{train}.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) + os.path.join(path, f'{validation}.jsonl'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}.jsonl'), fields, **kwargs) + os.path.join(path, f'{test}.jsonl'), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) -class ZeroShotRE(CQA, data.Dataset): +class ZeroShotRE(CQA): @staticmethod def sort_key(ex): @@ -1264,9 +1254,10 @@ class ZeroShotRE(CQA, data.Dataset): name = 'zre' - def __init__(self, path, field, subsample=None, **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') - cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) + cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), + str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') @@ -1279,7 +1270,7 @@ class ZeroShotRE(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break @@ -1287,7 +1278,7 @@ class ZeroShotRE(CQA, data.Dataset): logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super().__init__(examples, field, **kwargs) + super().__init__(examples, **kwargs) @classmethod def cache_splits(cls, path, train='train', validation='dev', test='test'): @@ -1316,21 +1307,21 @@ class ZeroShotRE(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}.jsonl'), fields, **kwargs) + os.path.join(path, f'{train}.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) + os.path.join(path, f'{validation}.jsonl'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}.jsonl'), fields, **kwargs) + os.path.join(path, f'{test}.jsonl'), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) @@ -1396,7 +1387,9 @@ class OntoNotesNER(CQA, data.Dataset): return ' '.join(raw.split()).strip() - def __init__(self, path, field, one_answer=True, subsample=None, path_to_files='.data/ontonotes-release-5.0/data/files', subtask='all', nones=True, **kwargs): + def __init__(self, path, one_answer=True, subsample=None, tokenize=None, lower=False, + path_to_files='.data/ontonotes-release-5.0/data/files', + subtask='all', nones=True, **kwargs): cached_path = kwargs.pop('cached_path') cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample), subtask, str(nones)) skip_cache_bool = kwargs.pop('skip_cache_bool') @@ -1416,7 +1409,7 @@ class OntoNotesNER(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break @@ -1424,7 +1417,7 @@ class OntoNotesNER(CQA, data.Dataset): logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super(OntoNotesNER, self).__init__(examples, field, **kwargs) + super(OntoNotesNER, self).__init__(examples, **kwargs) @classmethod @@ -1554,8 +1547,7 @@ class OntoNotesNER(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', - train='train', validation='development', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='development', test='test', **kwargs): path_to_files = os.path.join(root, 'ontonotes-release-5.0', 'data', 'files') assert os.path.exists(path_to_files) path = cls.download(root) @@ -1564,18 +1556,18 @@ class OntoNotesNER(CQA, data.Dataset): aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}.jsonl'), fields, **kwargs) + os.path.join(path, f'{train}.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}.jsonl'), fields, one_answer=False, **kwargs) + os.path.join(path, f'{validation}.jsonl'), one_answer=False, **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}.jsonl'), fields, one_answer=False, **kwargs) + os.path.join(path, f'{test}.jsonl'), one_answer=False, **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) -class SNLI(CQA, data.Dataset): +class SNLI(CQA): @staticmethod def sort_key(ex): @@ -1586,9 +1578,10 @@ class SNLI(CQA, data.Dataset): name = 'snli' - def __init__(self, path, field, subsample=None, **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') - cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) + cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), + str(subsample)) skip_cache_bool = kwargs.pop('skip_cache_bool') if os.path.exists(cache_name) and not skip_cache_bool: logger.info(f'Loading cached data from {cache_name}') @@ -1602,7 +1595,7 @@ class SNLI(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break @@ -1610,7 +1603,7 @@ class SNLI(CQA, data.Dataset): logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super().__init__(examples, field, **kwargs) + super().__init__(examples, **kwargs) @classmethod def cache_splits(cls, path, train='train', validation='dev', test='test'): @@ -1632,34 +1625,36 @@ class SNLI(CQA, data.Dataset): @classmethod - def splits(cls, fields, root='.data', train='train', validation='dev', test='test', **kwargs): + def splits(cls, root='.data', train='train', validation='dev', test='test', **kwargs): path = cls.download(root) cls.cache_splits(path) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, f'{train}.jsonl'), fields, **kwargs) + os.path.join(path, f'{train}.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, f'{validation}.jsonl'), fields, **kwargs) + os.path.join(path, f'{validation}.jsonl'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, f'{test}.jsonl'), fields, **kwargs) + os.path.join(path, f'{test}.jsonl'), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) -class JSON(CQA, data.Dataset): +class JSON(CQA): + name = 'json' @staticmethod def sort_key(ex): return data.interleave_keys(len(ex.context), len(ex.answer)) - def __init__(self, path, field, subsample=None, **kwargs): + def __init__(self, path, subsample=None, tokenize=None, lower=False, **kwargs): cached_path = kwargs.pop('cached_path') - cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), str(subsample)) + cache_name = os.path.join(cached_path, os.path.dirname(path).strip("/"), '.cache', os.path.basename(path), + str(subsample)) examples = [] skip_cache_bool = kwargs.pop('skip_cache_bool') @@ -1674,30 +1669,29 @@ class JSON(CQA, data.Dataset): context, question, answer = ex['context'], ex['question'], ex['answer'] examples.append(Example.from_raw(make_example_id(self, len(examples)), context, question, answer, - tokenize=field.tokenize, lower=field.lower)) + tokenize=tokenize, lower=lower)) if subsample is not None and len(examples) >= subsample: break os.makedirs(os.path.dirname(cache_name), exist_ok=True) logger.info(f'Caching data to {cache_name}') torch.save(examples, cache_name) - super(JSON, self).__init__(examples, field, **kwargs) + super(JSON, self).__init__(examples, **kwargs) @classmethod - def splits(cls, fields, name, root='.data', - train='train', validation='val', test='test', **kwargs): + def splits(cls, name, root='.data', train='train', validation='val', test='test', **kwargs): path = os.path.join(root, name) aux_data = None if kwargs.get('curriculum', False): kwargs.pop('curriculum') - aux_data = cls(os.path.join(path, 'aux.jsonl'), fields, **kwargs) + aux_data = cls(os.path.join(path, 'aux.jsonl'), **kwargs) train_data = None if train is None else cls( - os.path.join(path, 'train.jsonl'), fields, **kwargs) + os.path.join(path, 'train.jsonl'), **kwargs) validation_data = None if validation is None else cls( - os.path.join(path, 'val.jsonl'), fields, **kwargs) + os.path.join(path, 'val.jsonl'), **kwargs) test_data = None if test is None else cls( - os.path.join(path, 'test.jsonl'), fields, **kwargs) + os.path.join(path, 'test.jsonl'), **kwargs) return tuple(d for d in (train_data, validation_data, test_data, aux_data) if d is not None) diff --git a/decanlp/text/data/dataset.py b/decanlp/text/data/dataset.py index 2a5c12ab..62c312f3 100644 --- a/decanlp/text/data/dataset.py +++ b/decanlp/text/data/dataset.py @@ -18,19 +18,14 @@ class Dataset(torch.utils.data.Dataset): fields: A dictionary containing the name of each column together with its corresponding Field object. Two columns with the same Field object will share a vocabulary. - fields (dict[str, Field]): Contains the name of each column or field, together - with the corresponding Field object. Two fields with the same Field object - will have a shared vocabulary. """ sort_key = None - def __init__(self, examples, fields, filter_pred=None): + def __init__(self, examples, filter_pred=None, **kwargs): """Create a dataset from a list of Examples and Fields. Arguments: examples: List of Examples. - fields (List(tuple(str, Field))): The Fields to use in this tuple. The - string is a field name, and the Field is the associated field. filter_pred (callable or None): Use only examples for which filter_pred(example) is True, or use all examples if None. Default is None. @@ -41,16 +36,13 @@ class Dataset(torch.utils.data.Dataset): if make_list: examples = list(examples) self.examples = examples - self.fields = dict(fields) @classmethod - def splits(cls, path=None, root='.data', train=None, validation=None, + def splits(cls, root='.data', train=None, validation=None, test=None, **kwargs): """Create Dataset objects for multiple splits of a dataset. Arguments: - path (str): Common prefix of the splits' file paths, or None to use - the result of cls.download(root). root (str): Root dataset storage directory. Default is '.data'. train (str): Suffix to add to path for the train set, or None for no train set. Default is None. @@ -65,8 +57,7 @@ class Dataset(torch.utils.data.Dataset): split_datasets (tuple(Dataset)): Datasets for train, validation, and test splits in that order, if provided. """ - if path is None: - path = cls.download(root) + path = cls.download(root) train_data = None if train is None else cls( os.path.join(path, train), **kwargs) val_data = None if validation is None else cls( @@ -89,11 +80,6 @@ class Dataset(torch.utils.data.Dataset): for x in self.examples: yield x - def __getattr__(self, attr): - if attr in self.fields: - for x in self.examples: - yield getattr(x, attr) - @classmethod def download(cls, root, check=None): """Download and unzip an online archive (.zip, .gz, or .tgz). diff --git a/decanlp/text/data/field.py b/decanlp/text/data/field.py index 386e4f1e..65a06963 100644 --- a/decanlp/text/data/field.py +++ b/decanlp/text/data/field.py @@ -244,7 +244,7 @@ class Field(RawField): return (padded, lengths) return padded - def build_vocab(self, *args, **kwargs): + def build_vocab(self, field_names, *args, **kwargs): """Construct the Vocab object for this field from one or more datasets. Arguments: @@ -259,11 +259,7 @@ class Field(RawField): counter = Counter() sources = [] for arg in args: - if hasattr(arg, 'fields'): - sources += [getattr(arg, name) for name, field in - arg.fields.items() if field is self] - else: - sources.append(arg) + sources += [getattr(ex, name) for name in field_names for ex in arg] for data in sources: for x in data: if not self.sequential: diff --git a/decanlp/train.py b/decanlp/train.py index d57f77d9..39377ba6 100644 --- a/decanlp/train.py +++ b/decanlp/train.py @@ -34,15 +34,11 @@ import math import time import sys from copy import deepcopy - import logging from pprint import pformat from logging import handlers import numpy as np -from .utils.model_utils import init_model - import torch - from tensorboardX import SummaryWriter from . import arguments @@ -53,6 +49,8 @@ from .utils.saver import Saver from .utils.embeddings import load_embeddings from .text.data import ReversibleField from .data.numericalizer import DecoderVocabulary +from .data.example import Example +from .utils.model_utils import init_model def initialize_logger(args, rank='main'): @@ -78,12 +76,9 @@ def log(rank='main'): def prepare_data(args, field, logger): - if field is None: logger.info(f'Constructing field') - FIELD = ReversibleField(batch_first=True, init_token='', eos_token='', lower=args.lower, include_lengths=True) - else: - FIELD = field + field = ReversibleField(batch_first=True, init_token='', eos_token='', lower=args.lower, include_lengths=True) train_sets, val_sets, aux_sets, vocab_sets = [], [], [], [] for task in args.train_tasks: @@ -97,7 +92,7 @@ def prepare_data(args, field, logger): kwargs['cached_path'] = args.cached logger.info(f'Adding {task.name} to training datasets') - split = task.get_splits(FIELD, args.data, **kwargs) + split = task.get_splits(args.data, tokenize=field.tokenize, lower=args.lower, **kwargs) if args.use_curriculum: assert len(split) == 2 aux_sets.append(split[1]) @@ -118,7 +113,7 @@ def prepare_data(args, field, logger): kwargs['cached_path'] = args.cached logger.info(f'Adding {task.name} to validation datasets') - split = task.get_splits(FIELD, args.data, **kwargs) + split = task.get_splits(args.data, tokenize=field.tokenize, lower=args.lower, **kwargs) assert len(split) == 1 logger.info(f'{task.name} has {len(split[0])} validation examples') val_sets.append(split[0]) @@ -129,23 +124,23 @@ def prepare_data(args, field, logger): vectors = load_embeddings(args, logger) vocab_sets = (train_sets + val_sets) if len(vocab_sets) == 0 else vocab_sets logger.info(f'Building vocabulary') - FIELD.build_vocab(*vocab_sets, max_size=args.max_effective_vocab, vectors=vectors) + field.build_vocab(Example.vocab_fields, *vocab_sets, max_size=args.max_effective_vocab, vectors=vectors) - FIELD.decoder_vocab = DecoderVocabulary(FIELD.vocab.itos[:args.max_generative_vocab], FIELD.vocab) + field.decoder_vocab = DecoderVocabulary(field.vocab.itos[:args.max_generative_vocab], field.vocab) - logger.info(f'Vocabulary has {len(FIELD.vocab)} tokens') + logger.info(f'Vocabulary has {len(field.vocab)} tokens') logger.debug(f'The first 200 tokens:') - logger.debug(FIELD.vocab.itos[:200]) + logger.debug(field.vocab.itos[:200]) if args.use_curriculum: logger.info('Preprocessing auxiliary data for curriculum') - preprocess_examples(args, args.train_tasks, aux_sets, FIELD, logger, train=True) + preprocess_examples(args, args.train_tasks, aux_sets, field, logger, train=True) logger.info('Preprocessing training data') - preprocess_examples(args, args.train_tasks, train_sets, FIELD, logger, train=True) + preprocess_examples(args, args.train_tasks, train_sets, field, logger, train=True) logger.info('Preprocessing validation data') - preprocess_examples(args, args.val_tasks, val_sets, FIELD, logger, train=args.val_filter) + preprocess_examples(args, args.val_tasks, val_sets, field, logger, train=args.val_filter) - return FIELD, train_sets, val_sets, aux_sets + return field, train_sets, val_sets, aux_sets def get_learning_rate(i, args):