filter out same-sentence pairs

This commit is contained in:
mehrad 2020-04-13 22:19:32 -07:00
parent a77ecfdbb4
commit 6c7f14a34b
2 changed files with 10 additions and 7 deletions

View File

@ -263,7 +263,7 @@ def post_parse(args):
args.train_batch_values[i] = args.train_batch_size
if args.paired:
logger.warning('Using paired example training will increase effective batch size from {} to {}'.
format(args.train_batch_size, args.train_batch_size*(1+len(args.train_languages))))
format(args.train_batch_size, args.train_batch_size*len(args.train_languages)))
if len(args.val_batch_size) < len(args.val_task_names):
args.val_batch_size = len(args.val_task_names) * args.val_batch_size

View File

@ -29,6 +29,7 @@
from typing import NamedTuple, List
import itertools
import random
from .numericalizer.sequential_field import SequentialField
@ -78,15 +79,17 @@ class Batch(NamedTuple):
if paired:
example_pairs = []
# get all possible combinations of related example pairs
# TODO filter out pairs of same sentences
for i in range(0, len(examples), groups):
related_examples = [examples[j] for j in range(i, i+groups)]
example_pairs.extend(itertools.product(related_examples, related_examples))
example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]
# Filter out pairs of the same sentence: itertools.product also pairs every
# example with itself, and those self-pairs are dropped here.
example_pairs = [ex_pair for ex_pair in example_pairs if ex_pair[0] != ex_pair[1]]
# Shuffle example order (note we only do pairing during training).
# random.shuffle works in place and returns None, so its result must never be
# assigned. The pairs are shuffled once, before the per-field lists are built,
# so ids, contexts, questions and answers stay aligned with each other.
random.shuffle(example_pairs)
example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]
all_example_ids_pair = example_ids
all_context_inputs_pair = numericalizer.encode_pair(context_inputs, decoder_vocab, device=device)