fix shuffle + cap number of paired examples
parent 6c7f14a34b · commit fb7f5fe979
@@ -98,6 +98,8 @@ def parse_argv(parser):
     parser.add_argument('--paired', action='store_true',
                         help='Pair related examples before numericalizing the input (e.g. training with synthetic and paraphrase '
                              'sentence pairs for almond task)')
+    parser.add_argument('--max_pairs', type=int, default=1000000,
+                        help='Maximum number of pairs to make for each example group')
 
     parser.add_argument('--sentence_batching', action='store_true',
                         help='Batch same sentences together (used for multilingual tasks)')
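The two flags are meant to be used together; a hypothetical invocation for orientation (the entry-point name and any flag not defined in this hunk are assumptions, only --paired, --max_pairs and --sentence_batching come from the code above):

    python train.py --paired --max_pairs 100000 --sentence_batching ...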
@@ -248,7 +250,6 @@ def post_parse(args):
             args.sentence_batching = True
 
-
 
     args.train_batch_values = args.train_batch_tokens
     if len(args.train_task_names) > 1:
         if args.train_iterations is None:
@@ -262,8 +263,11 @@ def post_parse(args):
         if args.sentence_batching:
             args.train_batch_values[i] = args.train_batch_size
             if args.paired:
+                num_train_langs = len(args.train_languages.split('+'))
+                new_batch_size = int(args.train_batch_size * \
+                                     (1 + min(num_train_langs**2 - num_train_langs, args.max_pairs) / num_train_langs))
                 logger.warning('Using paired example training will increase effective batch size from {} to {}'.
-                               format(args.train_batch_size, args.train_batch_size*len(args.train_languages)))
+                               format(args.train_batch_size, new_batch_size))
 
     if len(args.val_batch_size) < len(args.val_task_names):
         args.val_batch_size = len(args.val_task_names) * args.val_batch_size
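Note on the new warning's arithmetic: with L train languages per sentence group, pairing yields L**2 - L ordered cross pairs (self-pairs are filtered out later), so each original example contributes min(L**2 - L, max_pairs) / L extra paired examples. A standalone sanity check of the formula (all values invented):

    # effective batch size under --paired, mirroring the formula added above
    train_batch_size = 32
    num_train_langs = 4                                      # e.g. en+de+fr+it
    max_pairs = 6                                            # cap from --max_pairs
    pairs_per_group = num_train_langs**2 - num_train_langs   # 12 ordered pairs
    effective = int(train_batch_size * (1 + min(pairs_per_group, max_pairs) / num_train_langs))
    print(effective)   # 80; uncapped it would be 32 * (1 + 12/4) = 128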
@@ -70,7 +70,7 @@ class Batch(NamedTuple):
     decoder_vocab: object
 
     @staticmethod
-    def from_examples(examples, numericalizer, device=None, paired=False, groups=None):
+    def from_examples(examples, numericalizer, device=None, paired=False, max_pairs=None, groups=None):
         assert all(isinstance(ex.example_id, str) for ex in examples)
 
         decoder_vocab = numericalizer.decoder_vocab.clone()
@@ -78,18 +78,22 @@ class Batch(NamedTuple):
 
         if paired:
             example_pairs = []
 
             # get all possible combinations of related example pairs
             for i in range(0, len(examples), groups):
                 related_examples = [examples[j] for j in range(i, i+groups)]
                 example_pairs.extend(itertools.product(related_examples, related_examples))
-            # filter out pairs of same sentences
+            # filter out pairs with same sentences
             example_pairs = [ex_pair for ex_pair in example_pairs if ex_pair[0] != ex_pair[1]]
 
-            # shuffle example orders (note we only do pairing during training)
-            example_ids = random.shuffle([ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs])
-            context_inputs = random.shuffle([((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs])
-            question_inputs = random.shuffle([((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs])
-            answer_inputs = random.shuffle([((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs])
+            # shuffle example orders and select first max_pairs of them
+            random.shuffle(example_pairs)
+            example_pairs = example_pairs[:max_pairs]
+
+            example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
+            context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
+            question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
+            answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]
 
             all_example_ids_pair = example_ids
             all_context_inputs_pair = numericalizer.encode_pair(context_inputs, decoder_vocab, device=device)
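This is the "fix shuffle" half of the commit: random.shuffle shuffles its argument in place and returns None, so each of the four deleted assignments bound its name to None rather than a shuffled list (and four independent shuffles would have desynchronized ids from inputs even if they had returned values). A minimal standalone illustration of the corrected pattern:

    import random

    pairs = [('a', 'b'), ('b', 'a'), ('a', 'c'), ('c', 'a')]
    print(random.shuffle(pairs))   # None: this is what the old code stored

    random.shuffle(pairs)          # shuffle once, in place, keeping all fields aligned
    print(pairs[:2])               # then keep at most max_pairs (here 2) pairs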
@@ -32,7 +32,7 @@ import torch
 import random
 
 from .example import Batch
-from ..tasks.generic_dataset import context_answer_len, processed_id, default_batch_fn
+from ..tasks.generic_dataset import context_answer_len, default_batch_fn
 
 
 class Iterator(torch.utils.data.IterableDataset):
@@ -80,9 +80,9 @@ class Iterator(torch.utils.data.IterableDataset):
         dataset = self.dataset
 
         if self.use_data_sort_key:
-            if self.sort_key == processed_id:
+            if self.groups:
                 batches = self._sentence_batching(dataset)
-            elif self.sort_key == context_answer_len:
+            else:
                 batches = self._bucket_batching(dataset)
         else:
             batches = self._batch(dataset, self.batch_size)
@@ -35,7 +35,7 @@ from collections import defaultdict
 
 from ..base_task import BaseTask
 from ..registry import register_task
-from ..generic_dataset import CQA, processed_id, context_answer_len, token_batch_fn, default_batch_fn
+from ..generic_dataset import CQA, context_answer_len, token_batch_fn, default_batch_fn
 from ...data_utils.example import Example
 
 from ..base_dataset import Split
@@ -108,7 +108,18 @@ class AlmondDataset(CQA):
 
 
 def is_entity(token):
-    return token[0].isupper()
+    try:
+        return token[0].isupper()
+    except IndexError:
+        return False  # empty token: treat as non-entity
+
+def process_id(ex):
+    id_ = ex.example_id.rsplit('/', 1)
+    id_ = id_[0] if len(id_) == 1 else id_[1]
+    # translated ids carry a 'T' prefix; strip it
+    if id_[0] == 'T':
+        id_ = id_[1:]
+    return id_
 
 
 class BaseAlmondTask(BaseTask):
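A quick check of the new process_id helper (the example ids below are invented; only the rsplit and 'T'-prefix handling come from the function above):

    class FakeExample:                 # stand-in for a real Example
        def __init__(self, example_id):
            self.example_id = example_id

    print(process_id(FakeExample('some-dataset/T12345')))   # '12345': translated id, prefix dropped
    print(process_id(FakeExample('12345')))                 # '12345': no '/' and no prefix, unchanged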
@@ -334,7 +345,7 @@ class AlmondMultiLingual(BaseAlmondTask):
     def get_train_processed_ids(self, split):
         all_ids = []
         for ex in split.examples:
-            all_ids.append(processed_id(ex))
+            all_ids.append(process_id(ex))
         return all_ids
 
     def get_splits(self, root, **kwargs):
@@ -360,7 +371,7 @@ class AlmondMultiLingual(BaseAlmondTask):
             for id_set in ids_sets:
                 assert set(id_set) == id_set_base, 'When using sentence batching your datasets should have matching ids'
 
-            sort_key_fn = processed_id
+            sort_key_fn = process_id
             batch_size_fn = default_batch_fn
         else:
             sort_key_fn = context_answer_len
@@ -53,15 +53,9 @@ def make_example_id(dataset, example_id):
 def context_answer_len(ex):
     return interleave_keys(len(ex.context), len(ex.answer))
 
-def processed_id(ex):
+def id_value(ex):
     id_ = ex.example_id.rsplit('/', 1)
     id_ = id_[0] if len(id_) == 1 else id_[1]
-    # translated
-    if id_[0] == 'T':
-        id_ = id_[1:]
-    # paraphrased
-    if id_[0] == 'P':
-        id_ = id_[1:]
     return id_
 
 # batch_size funcs
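For orientation, context_answer_len builds its sort key with interleave_keys, which bit-interleaves the two lengths into one Morton-style key so examples close in both context and answer length sort near each other (keeping padding waste low in token-based batches). A self-contained sketch of the idea, assuming 16-bit lengths (the real helper lives elsewhere in the codebase):

    def interleave_bits(a, b, width=16):
        # interleave the bits of a and b, MSB first (Z-order / Morton key)
        abits, bbits = format(a, f'0{width}b'), format(b, f'0{width}b')
        return int(''.join(x + y for x, y in zip(abits, bbits)), 2)

    print(interleave_bits(10, 12))   # 216
    print(interleave_bits(40, 3))    # 2181: far from (10, 12) in both dims, far in key space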
@@ -341,7 +341,7 @@ def train(args, devices, model, opt, lr_scheduler, train_sets, train_iterations,
 
     logger.info(f'Preparing iterators')
     main_device = devices[0]
-    train_iters = [(task, make_data_loader(x, numericalizer, tok, main_device, paired=args.paired, train=True))
+    train_iters = [(task, make_data_loader(x, numericalizer, tok, main_device, paired=args.paired, max_pairs=args.max_pairs, train=True))
                    for task, x, tok in zip(args.train_tasks, train_sets, args.train_batch_values)]
     train_iters = [(task, iter(train_iter)) for task, train_iter in train_iters]
@@ -238,7 +238,7 @@ def elapsed_time(log):
     return f'{day:02}:{hour:02}:{minutes:02}:{seconds:02}'
 
 
-def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=False, train=False, valid=False):
+def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=False, max_pairs=None, train=False, valid=False):
 
     iterator = Iterator(dataset,
                         batch_size,
@@ -251,7 +251,7 @@ def make_data_loader(dataset, numericalizer, batch_size, device=None, paired=Fal
                              batch_size=None,
                              collate_fn=lambda minibatch: Batch.from_examples(minibatch, numericalizer,
                                                                               device=device, paired=paired and train,
-                                                                              groups=iterator.groups))
+                                                                              max_pairs=max_pairs, groups=iterator.groups))
 
 
 def pad(x, new_channel, dim, val=None):
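End to end, the cap now threads through train() -> make_data_loader() -> Batch.from_examples(). A miniature, self-contained replay of the paired-batch assembly (toy strings stand in for real Examples; group size and cap are invented):

    import itertools, random

    groups, max_pairs = 3, 2                       # 3 related examples per group
    examples = ['en-0', 'de-0', 'fr-0', 'en-1', 'de-1', 'fr-1']

    example_pairs = []
    for i in range(0, len(examples), groups):
        related = examples[i:i + groups]
        example_pairs.extend(itertools.product(related, related))
    # drop self-pairs, shuffle once, then cap, following Batch.from_examples
    example_pairs = [p for p in example_pairs if p[0] != p[1]]
    random.shuffle(example_pairs)
    print(example_pairs[:max_pairs])               # at most max_pairs pairs survive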