diff --git a/spacy/util.py b/spacy/util.py
index 54ecb6edd..0f8de3ddf 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -656,45 +656,47 @@ def decaying(start, stop, decay):
         curr -= decay
 
 
-def minibatch_by_words(examples, size, tuples=True, count_words=len, tolerance=0.2):
+def minibatch_by_words(examples, size, count_words=len, tolerance=0.2, discard_oversize=False):
     """Create minibatches of roughly a given number of words. If any examples
     are longer than the specified batch length, they will appear in a batch by
-    themselves."""
+    themselves, or be discarded if discard_oversize=True."""
     if isinstance(size, int):
         size_ = itertools.repeat(size)
     elif isinstance(size, List):
         size_ = iter(size)
     else:
         size_ = size
-    examples = iter(examples)
-    oversize = []
-    while True:
-        batch_size = next(size_)
-        tol_size = batch_size * 0.2
-        batch = []
-        if oversize:
-            example = oversize.pop(0)
-            n_words = count_words(example.doc)
+
+    target_size = next(size_)
+    tol_size = target_size * tolerance
+    batch = []
+    current_size = 0
+
+    for example in examples:
+        n_words = count_words(example.doc)
+        # add the example to the current batch if it still fits
+        if (current_size + n_words) < (target_size + tol_size):
             batch.append(example)
-            batch_size -= n_words
-        while batch_size >= 1:
-            try:
-                example = next(examples)
-            except StopIteration:
-                if oversize:
-                    example = oversize.pop(0)
-                    batch.append(example)
-                if batch:
-                    yield batch
-                return
-            n_words = count_words(example.doc)
-            if n_words < (batch_size + tol_size):
-                batch_size -= n_words
-                batch.append(example)
+            current_size += n_words
+        else:
+            # if the current example exceeds the batch size, it is returned separately
+            # but only if discard_oversize=False.
+            if current_size > target_size:
+                if not discard_oversize:
+                    yield [example]
+            # yield the previous batch and start a new one
             else:
-                oversize.append(example)
-        if batch:
-            yield batch
+                yield batch
+                target_size = next(size_)
+                tol_size = target_size * tolerance
+                # In theory it may happen that the current example now exceeds the new target_size,
+                # but that seems like an unimportant edge case if batch sizes are variable anyway?
+                batch = [example]
+                current_size = n_words
+
+    # yield the final batch
+    if batch:
+        yield batch
 
 
 def itershuffle(iterable, bufsize=1000):
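
For reference, a minimal usage sketch of the new signature (not part of the patch). It assumes the patch above is applied and uses a hypothetical toy Example namedtuple as a stand-in for objects exposing a .doc attribute whose len() is the word count, so the default count_words=len applies.

# Usage sketch (assumes the patch above is applied; the toy Example class
# below is hypothetical and only stands in for objects with a .doc attribute).
from collections import namedtuple

from spacy.util import minibatch_by_words

Example = namedtuple("Example", ["doc"])

examples = [
    Example(doc=["just", "a", "few", "words"]),   # 4 words
    Example(doc=["another", "short", "one"]),     # 3 words
    Example(doc=["tok"] * 50),                    # much longer than the batch size
]

# Target batch size of 10 words; the default tolerance=0.2 lets a batch grow to
# just under 12 words. Per the new docstring, examples longer than the batch
# size appear in a batch by themselves, or are discarded when discard_oversize=True.
for batch in minibatch_by_words(examples, size=10):
    print([len(ex.doc) for ex in batch])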