diff --git a/genienlp/arguments.py b/genienlp/arguments.py
index 74e4e326..a5be3d6a 100644
--- a/genienlp/arguments.py
+++ b/genienlp/arguments.py
@@ -263,7 +263,7 @@ def post_parse(args):
             args.train_batch_values[i] = args.train_batch_size
     if args.paired:
         logger.warning('Using paired example training will increase effective batch size from {} to {}'.
-                       format(args.train_batch_size, args.train_batch_size*(1+len(args.train_languages))))
+                       format(args.train_batch_size, args.train_batch_size*len(args.train_languages)))
 
     if len(args.val_batch_size) < len(args.val_task_names):
         args.val_batch_size = len(args.val_task_names) * args.val_batch_size
diff --git a/genienlp/data_utils/example.py b/genienlp/data_utils/example.py
index 2c283c9f..2fc66ab6 100644
--- a/genienlp/data_utils/example.py
+++ b/genienlp/data_utils/example.py
@@ -29,6 +29,7 @@
 
 from typing import NamedTuple, List
 import itertools
+import random
 
 from .numericalizer.sequential_field import SequentialField
 
@@ -78,15 +79,21 @@ class Batch(NamedTuple):
         if paired:
             example_pairs = []
             # get all possible combinations of related example pairs
-            # TODO filter out pairs of same sentences
             for i in range(0, len(examples), groups):
                 related_examples = [examples[j] for j in range(i, i+groups)]
                 example_pairs.extend(itertools.product(related_examples, related_examples))
-
-            example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
-            context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
-            question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
-            answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]
+            # filter out pairs of same sentences
+            example_pairs = [ex_pair for ex_pair in example_pairs if ex_pair[0] != ex_pair[1]]
+
+            # shuffle example orders (note we only do pairing during training)
+            # random.shuffle() works in place and returns None, so shuffle the pair list once here
+            # instead of wrapping each derived list; building all four lists from the same shuffled
+            # pairs also keeps ids, contexts, questions and answers aligned index-for-index.
+            random.shuffle(example_pairs)
+            example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
+            context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
+            question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
+            answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]
 
             all_example_ids_pair = example_ids
             all_context_inputs_pair = numericalizer.encode_pair(context_inputs, decoder_vocab, device=device)