fix shuffle + cap number of paired examples

mehrad 2020-04-16 20:26:38 -07:00
parent 6c7f14a34b
commit fb7f5fe979
7 changed files with 39 additions and 26 deletions
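The core of the fix: random.shuffle shuffles a list in place and returns None, so the old code stored None in every paired input list; the new code shuffles the list of pairs once, in place, then truncates it to the new --max_pairs cap. A minimal sketch of the difference, using toy data:

    import random

    pairs = [('a', 'b'), ('b', 'a'), ('a', 'c')]

    # old behavior: shuffle returns None, so this assigns None
    broken = random.shuffle([x + '@' + y for x, y in pairs])
    print(broken)  # None

    # fixed behavior: shuffle in place, then cap (e.g. --max_pairs 2)
    random.shuffle(pairs)
    capped = pairs[:2]
    print(capped)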

View File

@@ -98,6 +98,8 @@ def parse_argv(parser):
     parser.add_argument('--paired', action='store_true',
                         help='Pair related examples before numericalizing the input (e.g. training with synthetic and paraphrase '
                              'sentence pairs for almond task)')
+    parser.add_argument('--max_pairs', type=int, default=1000000,
+                        help='Maximum number of pairs to make for each example group')
     parser.add_argument('--sentence_batching', action='store_true',
                         help='Batch same sentences together (used for multilingual tasks)')
@@ -248,7 +250,6 @@ def post_parse(args):
         args.sentence_batching = True
         args.train_batch_values = args.train_batch_tokens
     if len(args.train_task_names) > 1:
         if args.train_iterations is None:
@@ -262,8 +263,11 @@ def post_parse(args):
         if args.sentence_batching:
             args.train_batch_values[i] = args.train_batch_size
             if args.paired:
+                num_train_langs = len(args.train_languages.split('+'))
+                new_batch_size = int(args.train_batch_size * \
+                                     (1 + min(num_train_langs**2 - num_train_langs, args.max_pairs) / num_train_langs))
                 logger.warning('Using paired example training will increase effective batch size from {} to {}'.
-                               format(args.train_batch_size, args.train_batch_size*len(args.train_languages)))
+                               format(args.train_batch_size, new_batch_size))
     if len(args.val_batch_size) < len(args.val_task_names):
         args.val_batch_size = len(args.val_task_names) * args.val_batch_size
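For intuition about the new warning: each group of num_train_langs parallel sentences yields num_train_langs**2 - num_train_langs ordered pairs (self-pairs excluded), now capped by --max_pairs. A quick check of the formula with hypothetical values:

    # hypothetical values, just to evaluate the formula above
    train_batch_size = 8
    num_train_langs = 3        # e.g. --train_languages en+fa+it
    max_pairs = 1000000        # default --max_pairs

    pairs_per_group = min(num_train_langs**2 - num_train_langs, max_pairs)  # 6
    new_batch_size = int(train_batch_size * (1 + pairs_per_group / num_train_langs))
    print(new_batch_size)  # 8 * (1 + 6/3) = 24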

View File

@@ -70,7 +70,7 @@ class Batch(NamedTuple):
     decoder_vocab: object

     @staticmethod
-    def from_examples(examples, numericalizer, device=None, paired=False, groups=None):
+    def from_examples(examples, numericalizer, device=None, paired=False, max_pairs=None, groups=None):
         assert all(isinstance(ex.example_id, str) for ex in examples)
         decoder_vocab = numericalizer.decoder_vocab.clone()
@@ -78,18 +78,22 @@ class Batch(NamedTuple):
         if paired:
             example_pairs = []
             # get all possible combinations of related example pairs
             for i in range(0, len(examples), groups):
                 related_examples = [examples[j] for j in range(i, i+groups)]
                 example_pairs.extend(itertools.product(related_examples, related_examples))
-            # filter out pairs of same sentences
+            # filter out pairs with same sentences
             example_pairs = [ex_pair for ex_pair in example_pairs if ex_pair[0] != ex_pair[1]]
-            # shuffle example orders (note we only do pairing during training)
-            example_ids = random.shuffle([ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs])
-            context_inputs = random.shuffle([((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs])
-            question_inputs = random.shuffle([((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs])
-            answer_inputs = random.shuffle([((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs])
+            # shuffle example orders and select first max_pairs of them
+            random.shuffle(example_pairs)
+            example_pairs = example_pairs[:max_pairs]
+            example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
+            context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
+            question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
+            answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]

             all_example_ids_pair = example_ids
             all_context_inputs_pair = numericalizer.encode_pair(context_inputs, decoder_vocab, device=device)
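Putting the paired path together, a condensed, runnable sketch of what from_examples now does before numericalization (toy strings stand in for Example objects; encode_pair and the word masks are elided):

    import itertools
    import random

    examples = ['en-1', 'fa-1', 'it-1', 'en-2', 'fa-2', 'it-2']
    groups = 3       # related examples per group
    max_pairs = 4    # cap on pairs kept

    example_pairs = []
    for i in range(0, len(examples), groups):
        related = examples[i:i + groups]
        # all ordered combinations within one group
        example_pairs.extend(itertools.product(related, related))

    # drop self-pairs, then shuffle in place and cap
    example_pairs = [p for p in example_pairs if p[0] != p[1]]
    random.shuffle(example_pairs)
    example_pairs = example_pairs[:max_pairs]
    print(example_pairs)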

View File

@@ -32,7 +32,7 @@ import torch
 import random

 from .example import Batch
-from ..tasks.generic_dataset import context_answer_len, processed_id, default_batch_fn
+from ..tasks.generic_dataset import context_answer_len, default_batch_fn


 class Iterator(torch.utils.data.IterableDataset):

@@ -80,9 +80,9 @@ class Iterator(torch.utils.data.IterableDataset):
         dataset = self.dataset
         if self.use_data_sort_key:
-            if self.sort_key == processed_id:
+            if self.groups:
                 batches = self._sentence_batching(dataset)
-            elif self.sort_key == context_answer_len:
+            else:
                 batches = self._bucket_batching(dataset)
         else:
             batches = self._batch(dataset, self.batch_size)
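The dispatch above now keys off self.groups (set for multilingual sentence batching) rather than comparing sort_key against a specific function object. The diff does not show _sentence_batching itself; the sketch below is a hypothetical illustration of grouped batching, keeping each group of parallel sentences intact, and is an assumption rather than the real method body:

    # Assumption: this mirrors, not reproduces, the real _sentence_batching.
    def _sentence_batching(self, dataset):
        examples = list(dataset)
        # batch size rounded down to a whole number of groups
        step = max(1, self.batch_size // self.groups) * self.groups
        for i in range(0, len(examples), step):
            # whole groups of `self.groups` related sentences per batch
            yield examples[i:i + step]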

View File

@@ -35,7 +35,7 @@ from collections import defaultdict

 from ..base_task import BaseTask
 from ..registry import register_task
-from ..generic_dataset import CQA, processed_id, context_answer_len, token_batch_fn, default_batch_fn
+from ..generic_dataset import CQA, context_answer_len, token_batch_fn, default_batch_fn
 from ...data_utils.example import Example
 from ..base_dataset import Split
@@ -108,7 +108,18 @@ class AlmondDataset(CQA):

 def is_entity(token):
-    return token[0].isupper()
+    try:
+        return token[0].isupper()
+    except:
+        print('here')
+
+
+def process_id(ex):
+    id_ = ex.example_id.rsplit('/', 1)
+    id_ = id_[0] if len(id_) == 1 else id_[1]
+    # translated
+    if id_[0] == 'T':
+        id_ = id_[1:]
+    return id_
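For reference, what the new process_id extracts (the ids below are made up for illustration): keep the segment after the last '/', then strip the 'T' prefix that marks translated examples:

    def process_id(ex):  # as defined above
        id_ = ex.example_id.rsplit('/', 1)
        id_ = id_[0] if len(id_) == 1 else id_[1]
        if id_[0] == 'T':  # translated
            id_ = id_[1:]
        return id_

    class Ex:
        def __init__(self, example_id):
            self.example_id = example_id

    print(process_id(Ex('restaurants/T1234')))  # -> '1234'
    print(process_id(Ex('S5678')))              # -> 'S5678'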
class BaseAlmondTask(BaseTask):
@@ -334,7 +345,7 @@ class AlmondMultiLingual(BaseAlmondTask):
     def get_train_processed_ids(self, split):
         all_ids = []
         for ex in split.examples:
-            all_ids.append(processed_id(ex))
+            all_ids.append(process_id(ex))
         return all_ids

     def get_splits(self, root, **kwargs):

@@ -360,7 +371,7 @@ class AlmondMultiLingual(BaseAlmondTask):
             for id_set in ids_sets:
                 assert set(id_set) == id_set_base, 'When using sentence batching your datasets should have matching ids'
-            sort_key_fn = processed_id
+            sort_key_fn = process_id
             batch_size_fn = default_batch_fn
         else:
             sort_key_fn = context_answer_len

View File

@@ -53,15 +53,9 @@ def make_example_id(dataset, example_id):

 def context_answer_len(ex):
     return interleave_keys(len(ex.context), len(ex.answer))

-def processed_id(ex):
+def id_value(ex):
     id_ = ex.example_id.rsplit('/', 1)
     id_ = id_[0] if len(id_) == 1 else id_[1]
-    # translated
-    if id_[0] == 'T':
-        id_ = id_[1:]
-    # paraphrased
-    if id_[0] == 'P':
-        id_ = id_[1:]
     return id_
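context_answer_len above sorts by interleave_keys, which bit-interleaves the two lengths so that examples whose context and answer are both similar in length bucket together, reducing padding per minibatch. A sketch in the style of torchtext's interleave_keys, which is presumably the implementation this module borrows (an assumption, since the function body is not shown in this diff):

    def interleave_keys(a, b):
        # interleave the binary digits of a and b so that sorting by the
        # result groups examples where both lengths are similar
        def interleave(args):
            return ''.join([x for t in zip(*args) for x in t])
        return int(''.join(interleave(format(x, '016b') for x in (a, b))), 2)

    lengths = [(5, 40), (6, 6), (38, 7), (7, 8)]
    # pairs where both lengths are small sort next to each other
    print(sorted(lengths, key=lambda p: interleave_keys(*p)))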
# batch_size funcs

View File

@@ -341,7 +341,7 @@ def train(args, devices, model, opt, lr_scheduler, train_sets, train_iterations,
     logger.info(f'Preparing iterators')
     main_device = devices[0]
-    train_iters = [(task, make_data_loader(x, numericalizer, tok, main_device, paired=args.paired, train=True))
+    train_iters = [(task, make_data_loader(x, numericalizer, tok, main_device, paired=args.paired, max_pairs=args.max_pairs, train=True))
                    for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_values)]
     train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]

View File

@@ -238,7 +238,7 @@ def elapsed_time(log):
     return f'{day:02}:{hour:02}:{minutes:02}:{seconds:02}'

-def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=False, train=False, valid=False):
+def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=False, max_pairs=None, train=False, valid=False):
     iterator = Iterator(dataset,
                         batch_size,

@@ -251,7 +251,7 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=False, train=False, valid=False):
                        batch_size=None,
                        collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
                                                                         device=device, paired=paired and train,
-                                                                        groups=iterator.groups))
+                                                                        max_pairs=max_pairs, groups=iterator.groups))


 def pad(x, new_channel, dim, val=None):
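Finally, the batch_size=None / collate_fn pattern used above: automatic batching is disabled, so each item the Iterator yields is already a full minibatch, and the lambda closes over paired and max_pairs. A self-contained illustration of that pattern with stand-in data rather than genienlp objects:

    import torch

    def collate(minibatch, paired=True, max_pairs=2):
        # stand-in for Batch.from_examples(minibatch, ..., max_pairs=max_pairs)
        return {'size': len(minibatch), 'paired': paired, 'max_pairs': max_pairs}

    minibatches = [[1, 2, 3], [4, 5]]  # pretend the Iterator produced these
    loader = torch.utils.data.DataLoader(minibatches, batch_size=None,
                                         collate_fn=lambda mb: collate(mb))
    for b in loader:
        print(b)  # {'size': 3, ...} then {'size': 2, ...}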