rewrite minibatch_by_words function

This commit is contained in:
svlandeg 2020-06-02 15:22:54 +02:00
parent ec52e7f886
commit fdfd822936
1 changed files with 31 additions and 29 deletions

View File

@ -656,45 +656,47 @@ def decaying(start, stop, decay):
curr -= decay curr -= decay
def minibatch_by_words(examples, size, tuples=True, count_words=len, tolerance=0.2): def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
"""Create minibatches of roughly a given number of words. If any examples """Create minibatches of roughly a given number of words. If any examples
are longer than the specified batch length, they will appear in a batch by are longer than the specified batch length, they will appear in a batch by
themselves.""" themselves, or be discarded if discard_oversize=True."""
if isinstance(size, int): if isinstance(size, int):
size_ = itertools.repeat(size) size_ = itertools.repeat(size)
elif isinstance(size, List): elif isinstance(size, List):
size_ = iter(size) size_ = iter(size)
else: else:
size_ = size size_ = size
examples = iter(examples)
oversize = [] target_size = next(size_)
while True: tol_size = target_size * tolerance
batch_size = next(size_) batch = []
tol_size = batch_size * 0.2 current_size = 0
batch = []
if oversize: for example in examples:
example = oversize.pop(0) n_words = count_words(example.doc)
n_words = count_words(example.doc) # add the example to the current batch if it still fits
if (current_size + n_words) < (target_size + tol_size):
batch.append(example) batch.append(example)
batch_size -= n_words current_size += n_words
while batch_size >= 1: else:
try: # if the current example exceeds the batch size, it is returned separately
example = next(examples) # but only if discard_oversize=False.
except StopIteration: if current_size > target_size:
if oversize: if not discard_oversize:
example = oversize.pop(0) yield [example]
batch.append(example) # yield the previous batch and start a new one
if batch:
yield batch
return
n_words = count_words(example.doc)
if n_words < (batch_size + tol_size):
batch_size -= n_words
batch.append(example)
else: else:
oversize.append(example) yield batch
if batch: target_size = next(size_)
yield batch tol_size = target_size * tolerance
# In theory it may happen that the current example now exceeds the new target_size,
# but that seems like an unimportant edge case if batch sizes are variable anyway?
batch = [example]
current_size = n_words
# yield the final batch
if batch:
yield batch
def itershuffle(iterable, bufsize=1000): def itershuffle(iterable, bufsize=1000):