filter out same-sentence pairs
This commit is contained in:
parent
a77ecfdbb4
commit
6c7f14a34b
|
@ -263,7 +263,7 @@ def post_parse(args):
|
|||
args.train_batch_values[i] = args.train_batch_size
|
||||
if args.paired:
|
||||
logger.warning('Using paired example training will increase effective batch size from {} to {}'.
|
||||
format(args.train_batch_size, args.train_batch_size*(1+len(args.train_languages))))
|
||||
format(args.train_batch_size, args.train_batch_size*len(args.train_languages)))
|
||||
|
||||
if len(args.val_batch_size) < len(args.val_task_names):
|
||||
args.val_batch_size = len(args.val_task_names) * args.val_batch_size
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
|
||||
from typing import NamedTuple, List
|
||||
import itertools
|
||||
import random
|
||||
|
||||
from .numericalizer.sequential_field import SequentialField
|
||||
|
||||
|
@ -78,15 +79,17 @@ class Batch(NamedTuple):
|
|||
if paired:
|
||||
example_pairs = []
|
||||
# get all possible combinations of related example pairs
|
||||
# TODO filter out pairs of same sentences
|
||||
for i in range(0, len(examples), groups):
|
||||
related_examples = [examples[j] for j in range(i, i+groups)]
|
||||
example_pairs.extend(itertools.product(related_examples, related_examples))
|
||||
|
||||
example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
|
||||
context_inputs = [((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask)) for ex_a, ex_b in example_pairs]
|
||||
question_inputs = [((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask)) for ex_a, ex_b in example_pairs]
|
||||
answer_inputs = [((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask)) for ex_a, ex_b in example_pairs]
|
||||
# filter out pairs of same sentences
|
||||
example_pairs = [ex_pair for ex_pair in example_pairs if ex_pair[0] != ex_pair[1]]
|
||||
|
||||
# Shuffle the order of the example pairs (pairing is only done during training).
# random.shuffle() works in place and returns None, so its result must not be
# assigned. The pair list is shuffled exactly once, before the per-field lists
# are derived, so ids, contexts, questions and answers stay aligned index by index.
random.shuffle(example_pairs)

example_ids = [ex_a.example_id + '@' + ex_b.example_id for ex_a, ex_b in example_pairs]
context_inputs = [
    ((ex_a.context, ex_a.context_word_mask), (ex_b.context, ex_b.context_word_mask))
    for ex_a, ex_b in example_pairs
]
question_inputs = [
    ((ex_a.question, ex_a.question_word_mask), (ex_b.question, ex_b.question_word_mask))
    for ex_a, ex_b in example_pairs
]
answer_inputs = [
    ((ex_a.answer, ex_a.answer_word_mask), (ex_b.answer, ex_b.answer_word_mask))
    for ex_a, ex_b in example_pairs
]
|
||||
|
||||
all_example_ids_pair = example_ids
|
||||
all_context_inputs_pair = numericalizer.encode_pair(context_inputs, decoder_vocab, device=device)
|
||||
|
|
Loading…
Reference in New Issue