From fb7f5fe97920f99605dd0740ffeeaa25faf5a5c6 Mon Sep 17 00:00:00 2001
From: mehrad
Date: Thu, 16 Apr 2020 20:26:38 -0700
Subject: [PATCH] fix shuffle + cap number of paired examples

---
 genienlp/arguments.py             |  8 ++++++--
 genienlp/data_utils/example.py    | 18 +++++++++++++-------
 genienlp/data_utils/iterator.py   |  6 +++---
 genienlp/tasks/almond/__init__.py | 19 +++++++++++++++----
 genienlp/tasks/generic_dataset.py |  8 +-------
 genienlp/train.py                 |  2 +-
 genienlp/util.py                  |  4 ++--
 7 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/genienlp/arguments.py b/genienlp/arguments.py
index a5be3d6a..b9aaf31c 100644
--- a/genienlp/arguments.py
+++ b/genienlp/arguments.py
@@ -98,6 +98,8 @@ def parse_argv(parser):
     parser.add_argument('--paired', action='store_true',
                         help='Pair related examples before numericalizing the input (e.g. training with synthetic and paraphrase '
                              'sentence pairs for almond task)')
+    parser.add_argument('--max_pairs', type=int, default=1000000,
+                        help='Maximum number of pairs to make for each example group')
     parser.add_argument('--sentence_batching', action='store_true',
                         help='Batch same sentences together (used for multilingual tasks)')
 
@@ -248,7 +250,6 @@ def post_parse(args):
         args.sentence_batching = True
 
-    args.train_batch_values = args.train_batch_tokens
     if len(args.train_task_names) > 1:
         if args.train_iterations is None:
@@ -262,8 +263,11 @@ def post_parse(args):
         if args.sentence_batching:
             args.train_batch_values[i] = args.train_batch_size
             if args.paired:
+                num_train_langs = len(args.train_languages.split('+'))
+                new_batch_size = int(args.train_batch_size * \
+                                     (1 + min(num_train_langs**2 - num_train_langs, args.max_pairs) / num_train_langs))
                 logger.warning('Using paired example training will increase effective batch size from {} to {}'.
-                               format(args.train_batch_size, args.train_batch_size*len(args.train_languages)))
+                               format(args.train_batch_size, new_batch_size))
 
     if len(args.val_batch_size) < len(args.val_task_names):
         args.val_batch_size = len(args.val_task_names) * args.val_batch_size
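As a sanity check on the revised estimate in post_parse(), here is a small standalone sketch. The three-language setup ('en+fa+de') and the batch size of 20 are invented illustration values, not part of the patch:

```python
# Sketch of the effective-batch-size estimate added to post_parse() above.
# Assumed illustration values: three training languages, --train_batch_size 20.
train_batch_size = 20
train_languages = 'en+fa+de'   # hypothetical --train_languages value
max_pairs = 1000000            # the new --max_pairs default

num_train_langs = len(train_languages.split('+'))                        # 3
# Each group of num_train_langs related examples yields num_train_langs**2
# ordered pairs; dropping the self-pairs leaves num_train_langs**2 - num_train_langs,
# capped at max_pairs (see the example.py change below).
pairs_per_group = min(num_train_langs**2 - num_train_langs, max_pairs)   # 6
new_batch_size = int(train_batch_size * (1 + pairs_per_group / num_train_langs))
print(new_batch_size)  # 60

# The old warning multiplied by len(args.train_languages), i.e. the length of
# the string 'en+fa+de' (8 characters), which would have reported 160 here.
```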
diff --git a/genienlp/data_utils/example.py b/genienlp/data_utils/example.py
index 2fc66ab6..80b51b26 100644
--- a/genienlp/data_utils/example.py
+++ b/genienlp/data_utils/example.py
@@ -70,7 +70,7 @@ class Batch(NamedTuple):
     decoder_vocab: object
 
     @staticmethod
-    def from_examples(examples, numericalizer, device=None, paired=False, groups=None):
+    def from_examples(examples, numericalizer, device=None, paired=False, max_pairs=None, groups=None):
         assert all(isinstance(ex.example_id, str) for ex in examples)
 
         decoder_vocab = numericalizer.decoder_vocab.clone()
@@ -78,18 +78,22 @@ class Batch(NamedTuple):
 
         if paired:
             example_pairs = []
+            # get all possible combinations of related example pairs
             for i in range(0, len(examples), groups):
                 related_examples = [examples[j] for j in range(i, i+groups)]
                 example_pairs.extend(itertools.product(related_examples, related_examples))
 
-            # filter out pairs of same sentences
+            # filter out pairs with same sentences
            example_pairs = [ex_pair for ex_pair in example_pairs if ex_pair[0] != ex_pair[1]]
 
-            # shuffle example orders (note we only do pairing during training)
-            example_ids = random.shuffle([ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs])
-            context_inputs = random.shuffle([((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs])
-            question_inputs = random.shuffle([((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs])
-            answer_inputs = random.shuffle([((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs])
+            # shuffle example orders and select first max_pairs of them
+            random.shuffle(example_pairs)
+            example_pairs = example_pairs[:max_pairs]
+
+            example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
+            context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
+            question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
+            answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]
 
             all_example_ids_pair = example_ids
             all_context_inputs_pair = numericalizer.encode_pair(context_inputs, decoder_vocab, device=device)
diff --git a/genienlp/data_utils/iterator.py b/genienlp/data_utils/iterator.py
index 77a67aa9..81b4e95e 100644
--- a/genienlp/data_utils/iterator.py
+++ b/genienlp/data_utils/iterator.py
@@ -32,7 +32,7 @@ import torch
 import random
 
 from .example import Batch
-from ..tasks.generic_dataset import context_answer_len, processed_id, default_batch_fn
+from ..tasks.generic_dataset import context_answer_len, default_batch_fn
 
 
 class Iterator(torch.utils.data.IterableDataset):
@@ -80,9 +80,9 @@ class Iterator(torch.utils.data.IterableDataset):
             dataset = self.dataset
 
         if self.use_data_sort_key:
-            if self.sort_key == processed_id:
+            if self.groups:
                 batches = self._sentence_batching(dataset)
-            elif self.sort_key == context_answer_len:
+            else:
                 batches = self._bucket_batching(dataset)
         else:
             batches = self._batch(dataset, self.batch_size)
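The "fix shuffle" part of the commit title refers to random.shuffle shuffling in place and returning None: the old assignments in Batch.from_examples therefore set example_ids, context_inputs, question_inputs and answer_inputs all to None (and even a value-returning shuffle would have reordered the four lists independently, breaking their alignment). A minimal sketch of the corrected pattern, using an invented toy group of ids:

```python
import itertools
import random

# Toy stand-in for one group of related examples (invented ids, not real data).
group = ['ex1', 'ex2', 'ex3']
max_pairs = 2

# All ordered pairs within the group, minus self-pairs, as in Batch.from_examples.
pairs = [p for p in itertools.product(group, group) if p[0] != p[1]]  # 6 pairs

# Old (broken) pattern: random.shuffle returns None, so the assignment stores None.
broken = random.shuffle(list(pairs))
assert broken is None

# New pattern: shuffle the pair list in place once, then cap it at max_pairs.
random.shuffle(pairs)
pairs = pairs[:max_pairs]
print(pairs)  # two randomly chosen ordered pairs, e.g. [('ex2', 'ex1'), ('ex3', 'ex2')]
```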
diff --git a/genienlp/tasks/almond/__init__.py b/genienlp/tasks/almond/__init__.py
index 6324dce6..baa74a6d 100644
--- a/genienlp/tasks/almond/__init__.py
+++ b/genienlp/tasks/almond/__init__.py
@@ -35,7 +35,7 @@ from collections import defaultdict
 
 from ..base_task import BaseTask
 from ..registry import register_task
-from ..generic_dataset import CQA, processed_id, context_answer_len, token_batch_fn, default_batch_fn
+from ..generic_dataset import CQA, context_answer_len, token_batch_fn, default_batch_fn
 from ...data_utils.example import Example
 from ..base_dataset import Split
 
@@ -108,7 +108,18 @@ class AlmondDataset(CQA):
 
 
 def is_entity(token):
-    return token[0].isupper()
+    try:
+        return token[0].isupper()
+    except IndexError:
+        return False
+
+def process_id(ex):
+    id_ = ex.example_id.rsplit('/', 1)
+    id_ = id_[0] if len(id_) == 1 else id_[1]
+    # translated
+    if id_[0] == 'T':
+        id_ = id_[1:]
+    return id_
 
 
 class BaseAlmondTask(BaseTask):
@@ -334,7 +345,7 @@ class AlmondMultiLingual(BaseAlmondTask):
     def get_train_processed_ids(self, split):
         all_ids = []
         for ex in split.examples:
-            all_ids.append(processed_id(ex))
+            all_ids.append(process_id(ex))
         return all_ids
 
     def get_splits(self, root, **kwargs):
@@ -360,7 +371,7 @@ class AlmondMultiLingual(BaseAlmondTask):
             for id_set in ids_sets:
                 assert set(id_set) == id_set_base, 'When using sentence batching your datasets should have matching ids'
 
-            sort_key_fn = processed_id
+            sort_key_fn = process_id
             batch_size_fn = default_batch_fn
         else:
             sort_key_fn = context_answer_len
diff --git a/genienlp/tasks/generic_dataset.py b/genienlp/tasks/generic_dataset.py
index fb550bf3..8aa8a6ff 100644
--- a/genienlp/tasks/generic_dataset.py
+++ b/genienlp/tasks/generic_dataset.py
@@ -53,15 +53,9 @@ def make_example_id(dataset, example_id):
 def context_answer_len(ex):
     return interleave_keys(len(ex.context), len(ex.answer))
 
-def processed_id(ex):
+def id_value(ex):
     id_ = ex.example_id.rsplit('/', 1)
     id_ = id_[0] if len(id_) == 1 else id_[1]
-    # translated
-    if id_[0] == 'T':
-        id_ = id_[1:]
-    # paraphrased
-    if id_[0] == 'P':
-        id_ = id_[1:]
     return id_
 
 # batch_size funcs
diff --git a/genienlp/train.py b/genienlp/train.py
index 18a1ff85..0da6b64e 100644
--- a/genienlp/train.py
+++ b/genienlp/train.py
@@ -341,7 +341,7 @@ def train(args, devices, model, opt, lr_scheduler, train_sets, train_iterations,
     logger.info(f'Preparing iterators')
     main_device = devices[0]
 
-    train_iters = [(task, make_data_loader(x, numericalizer, tok, main_device, paired=args.paired, train=True))
+    train_iters = [(task, make_data_loader(x, numericalizer, tok, main_device, paired=args.paired, max_pairs=args.max_pairs, train=True))
                    for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_values)]
     train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
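The new process_id helper replaces processed_id for the multilingual Almond task, while the slimmed-down id_value left in generic_dataset.py no longer strips any prefix; process_id itself only removes the 'T' (translated) marker, not 'P'. A quick standalone sketch of what it yields, adapted to take the id string directly and applied to invented example ids:

```python
# Standalone copy of the process_id logic added to tasks/almond/__init__.py,
# taking the id string directly instead of an Example object.
def process_id(example_id):
    id_ = example_id.rsplit('/', 1)
    id_ = id_[0] if len(id_) == 1 else id_[1]
    if id_[0] == 'T':   # translated ids carry a 'T' prefix
        id_ = id_[1:]
    return id_

# Invented ids for illustration only.
print(process_id('restaurants/T123'))  # '123'  -> translated prefix stripped
print(process_id('restaurants/P123'))  # 'P123' -> paraphrase prefix kept
print(process_id('123'))               # '123'  -> id without a dataset prefix
```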
diff --git a/genienlp/util.py b/genienlp/util.py
index 6700d515..f53aac99 100644
--- a/genienlp/util.py
+++ b/genienlp/util.py
@@ -238,7 +238,7 @@ def elapsed_time(log):
     return f'{day:02}:{hour:02}:{minutes:02}:{seconds:02}'
 
 
-def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=False, train=False, valid=False):
+def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=False, max_pairs=None, train=False, valid=False):
     iterator = Iterator(dataset,
                         batch_size,
@@ -251,7 +251,7 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=Fal
                         batch_size=None,
                         collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer, device=device,
                                                                          paired=paired and train,
-                                                                         groups=iterator.groups))
+                                                                         max_pairs=max_pairs, groups=iterator.groups))
 
 
 def pad(x, new_channel, dim, val=None):
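The train.py and util.py hunks only thread the new value through: make_data_loader accepts max_pairs and forwards it to Batch.from_examples inside the DataLoader's collate_fn. The sketch below imitates that wiring with invented stand-ins (MinibatchIterator and fake_from_examples are not genienlp classes); it relies on PyTorch applying collate_fn to each item an IterableDataset yields when batch_size=None.

```python
import torch

class MinibatchIterator(torch.utils.data.IterableDataset):
    """Stand-in for genienlp's Iterator: each step yields a ready-made minibatch (a list)."""
    def __init__(self, examples, batch_size):
        self.examples, self.batch_size = examples, batch_size

    def __iter__(self):
        for i in range(0, len(self.examples), self.batch_size):
            yield self.examples[i:i + self.batch_size]

def fake_from_examples(minibatch, paired=False, max_pairs=None):
    # Stand-in for Batch.from_examples: just report the kwargs it received.
    return (len(minibatch), paired, max_pairs)

def make_data_loader(dataset, paired=False, max_pairs=None, train=False):
    # batch_size=None disables automatic batching, so collate_fn is applied to
    # each list the iterator yields; the lambda threads max_pairs through,
    # mirroring the util.py change above.
    return torch.utils.data.DataLoader(
        dataset, batch_size=None,
        collate_fn=lambda minibatch: fake_from_examples(minibatch,
                                                        paired=paired and train,
                                                        max_pairs=max_pairs))

loader = make_data_loader(MinibatchIterator(list(range(10)), 4),
                          paired=True, max_pairs=100, train=True)
print(list(loader))  # [(4, True, 100), (4, True, 100), (2, True, 100)]
```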